From 46346a779d24d306cfd7e35c2244d01e39fadbf2 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sat, 2 Aug 2025 08:11:43 +0300 Subject: [PATCH 01/74] WIP LLM pipeline and dataset implementation --- WORKSPACE | 34 +- flutter/android/android-docker.mk | 4 +- flutter/assets/icons/ic_task_llm.svg | 99 + flutter/assets/tasks.pbtxt | 27 + flutter/cpp/binary/BUILD | 1 + flutter/cpp/binary/main.cc | 25 +- flutter/cpp/datasets/BUILD | 31 + flutter/cpp/datasets/mmlu_gen.cc | 78 + flutter/cpp/datasets/mmlu_gen.h | 61 + .../datasets/mmlu_utils/generate_tfrecords.py | 37 + flutter/cpp/flutter/BUILD | 1 + flutter/cpp/proto/mlperf_task.proto | 1 + flutter/third_party/BUILD | 247 ++ ...-in-tensorflow-lite-tools-evaluation.patch | 9 +- mobile_back_tflite/cpp/backend_tflite/BUILD | 8 + .../tflite_settings_android.pbtxt | 23 + .../cpp/backend_tflite/llm_pipeline.cc | 380 +++ .../cpp/backend_tflite/llm_pipeline.h | 159 ++ .../backend_tflite/single_model_pipeline.cc | 2 +- patches/com_google_sentencepiece.diff | 2357 +++++++++++++++++ patches/darts_clone.BUILD | 12 + patches/darts_no_exceptions.diff | 87 + patches/ndk_25_r14.diff | 184 -- patches/sentencepiece.BUILD | 165 ++ 24 files changed, 3832 insertions(+), 200 deletions(-) create mode 100644 flutter/assets/icons/ic_task_llm.svg create mode 100644 flutter/cpp/datasets/mmlu_gen.cc create mode 100644 flutter/cpp/datasets/mmlu_gen.h create mode 100644 flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py create mode 100644 flutter/third_party/BUILD create mode 100644 mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc create mode 100644 mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h create mode 100644 patches/com_google_sentencepiece.diff create mode 100644 patches/darts_clone.BUILD create mode 100644 patches/darts_no_exceptions.diff delete mode 100644 patches/ndk_25_r14.diff create mode 100644 patches/sentencepiece.BUILD diff --git a/WORKSPACE b/WORKSPACE index dbcbc5c2f..cc6b12164 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -38,17 +38,39 @@ http_archive( "//:flutter/third_party/use_unsigned_char.patch", # Fix tensorflow not being able to read image files on Windows "//:flutter/third_party/tensorflow-fix-file-opening-mode-for-Windows.patch", - "//:flutter/third_party/tf-eigen.patch", - # NDK 25 support - "//patches:ndk_25_r14.diff", + #"//:flutter/third_party/tf-eigen.patch", ] + PATCH_FILE, - sha256 = "ce357fd0728f0d1b0831d1653f475591662ec5bca736a94ff789e6b1944df19f", - strip_prefix = "tensorflow-2.14.0", + sha256 = "9cc4d5773b8ee910079baaecb4086d0c28939f024dd74b33fc5e64779b6533dc", + strip_prefix = "tensorflow-2.17.0", urls = [ - "https://github.com/tensorflow/tensorflow/archive/v2.14.0.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/v2.17.0.tar.gz", ], ) +http_archive( + name = "com_google_sentencepiece", + strip_prefix = "sentencepiece-0.1.96", + sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754", + urls = [ + "https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip" + ], + build_file = "@//patches:sentencepiece.BUILD", + patches = ["@//patches:com_google_sentencepiece.diff"], + patch_args = ["-p1"], +) + +http_archive( + name = "darts_clone", + sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", + strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983", + urls = [ + "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", + ], + build_file = "@//patches:darts_clone.BUILD", + patches = ["//patches:darts_no_exceptions.diff"], + patch_args = ["-p0"], +) + load( "@org_tensorflow//tensorflow/tools/toolchains/python:python_repo.bzl", "python_repository", diff --git a/flutter/android/android-docker.mk b/flutter/android/android-docker.mk index 6c5c55f4a..2e3d0092a 100644 --- a/flutter/android/android-docker.mk +++ b/flutter/android/android-docker.mk @@ -19,7 +19,7 @@ user_id=$(shell id -u) .PHONY: flutter/android/docker/image flutter/android/docker/image: output/docker/mlperf_mobile_flutter_android_${user_id}.stamp output/docker/mlperf_mobile_flutter_android_${user_id}.stamp: flutter/android/docker/Dockerfile - docker image build -t ${DOCKER_IMAGE_TAG} flutter/android/docker + DOCKER_BUILDKIT=1 docker buildx build --tag ${DOCKER_IMAGE_TAG} flutter/android/docker mkdir -p output/docker touch $@ @@ -68,4 +68,4 @@ docker/flutter/android/release: flutter/check-release-env flutter/android/docker docker/flutter/clean: flutter/check-release-env MSYS2_ARG_CONV_EXCL="*" docker run \ ${flutter_common_docker_flags} \ - make flutter/clean \ No newline at end of file + make flutter/clean diff --git a/flutter/assets/icons/ic_task_llm.svg b/flutter/assets/icons/ic_task_llm.svg new file mode 100644 index 000000000..ffede9c72 --- /dev/null +++ b/flutter/assets/icons/ic_task_llm.svg @@ -0,0 +1,99 @@ + + + + + + + + + + + + + + + + + diff --git a/flutter/assets/tasks.pbtxt b/flutter/assets/tasks.pbtxt index 12e573973..8cb92d03e 100644 --- a/flutter/assets/tasks.pbtxt +++ b/flutter/assets/tasks.pbtxt @@ -336,6 +336,33 @@ task { } } +task { + id: "llm" + name: "LLM" + max_throughput: 2000 + max_accuracy: 1.0 + scenario: "SingleStream" + runs { + normal { + min_query_count: 1024 + min_duration: 60 + max_duration: 300 + } + } + datasets { + type: MMLU + tiny { + name: "TinyMMLU prompt set for LLM" + input_path: "https://thee.dev/mlc/data.tfrecord" #TODO placeholder + input_checksum: "b564d2c228a867148fa7d6df415a0368" + } + } + model { + id: "LLM" + name: "LLM" + } +} + task { id: "stable_diffusion" name: "Stable Diffusion" diff --git a/flutter/cpp/binary/BUILD b/flutter/cpp/binary/BUILD index ec07751e4..595421343 100644 --- a/flutter/cpp/binary/BUILD +++ b/flutter/cpp/binary/BUILD @@ -55,6 +55,7 @@ cc_binary( "//flutter/cpp/datasets:coco", "//flutter/cpp/datasets:coco_gen", "//flutter/cpp/datasets:imagenet", + "//flutter/cpp/datasets:mmlu_gen", "//flutter/cpp/datasets:snu_sr", "//flutter/cpp/datasets:squad", "//flutter/cpp/proto:mlperf_task_cc_proto", diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index 89f5758ad..aadccf855 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -26,6 +26,7 @@ limitations under the License. #include "flutter/cpp/datasets/coco.h" #include "flutter/cpp/datasets/coco_gen.h" #include "flutter/cpp/datasets/imagenet.h" +#include "flutter/cpp/datasets/mmlu_gen.h" #include "flutter/cpp/datasets/snu_sr.h" #include "flutter/cpp/datasets/squad.h" #include "flutter/cpp/mlperf_driver.h" @@ -67,6 +68,8 @@ DatasetConfig::DatasetType Str2DatasetType(absl::string_view name) { return DatasetConfig::SNUSR; } else if (absl::EqualsIgnoreCase(name, "COCOGEN")) { return DatasetConfig::COCOGEN; + } else if (absl::EqualsIgnoreCase(name, "MMLU")) { + return DatasetConfig::MMLU; } else if (absl::EqualsIgnoreCase(name, "DUMMY")) { return DatasetConfig::NONE; } else { @@ -88,6 +91,8 @@ DatasetConfig::DatasetType BenchmarkId2DatasetType(absl::string_view name) { return DatasetConfig::SNUSR; } else if (absl::StartsWith(name, "stable_diffusion")) { return DatasetConfig::COCOGEN; + } else if (absl::StartsWith(name, "llm")) { + return DatasetConfig::MMLU; } else { LOG(FATAL) << "Unrecognized benchmark_id: " << name; return DatasetConfig::NONE; @@ -113,7 +118,7 @@ int Main(int argc, char *argv[]) { "Benchmark ID. One of image_classification, " "image_classification_v2, object_detection, " "natural_language_processing, " - "image_segmentation_v2, super_resolution, stable_diffusion, " + "image_segmentation_v2, super_resolution, stable_diffusion, LLM, " "image_classification_offline, image_classification_offline_v2", Flag::kPositional)}; Flags::Parse(&argc, const_cast(argv), flag_list); @@ -389,6 +394,24 @@ int Main(int argc, char *argv[]) { flag_list.insert(flag_list.end(), dataset_flags.begin(), dataset_flags.end()); } break; + case DatasetConfig::MMLU: { + LOG(INFO) << "TinyMMLU dataset for LLM benchmark"; + std::string input_tfrecord, input_clip_model = ""; + std::vector dataset_flags{ + Flag::CreateFlag( + "input_tfrecord", &input_tfrecord, + "Path to the tfrecord file containing inputs for the model.", + Flag::kRequired), + }; + + if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && + backend) { + dataset.reset(new MmluGen(backend.get(), input_tfrecord)); + } + // Adds to flag_list for showing help. + flag_list.insert(flag_list.end(), dataset_flags.begin(), + dataset_flags.end()); + } break; case DatasetConfig::NONE: default: break; diff --git a/flutter/cpp/datasets/BUILD b/flutter/cpp/datasets/BUILD index 5e44ee59a..ebcd71587 100644 --- a/flutter/cpp/datasets/BUILD +++ b/flutter/cpp/datasets/BUILD @@ -204,3 +204,34 @@ cc_library( "@org_tensorflow//tensorflow/lite/tools/evaluation/stages:object_detection_average_precision_stage", ], ) + +cc_library( + name = "mmlu_gen", + srcs = [ + "mmlu_gen.cc", + ], + hdrs = [ + "mmlu_gen.h", + "utils.h", + ], + copts = tflite_copts() + select({ + "//flutter/android/commonlibs:use_asan": [ + "-fsanitize=address", + "-g", + "-O1", + "-fno-omit-frame-pointer", + ], + "//conditions:default": [], + }), + deps = [ + ":allocator", + "//flutter/cpp:mlperf_driver", + "//flutter/cpp:utils", + "//flutter/cpp/backends:external", + "//flutter/cpp/datasets/squad_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_protobuf//:protobuf", + "@org_tensorflow//tensorflow/lite/tools/evaluation:utils", + "@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + ], +) diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc new file mode 100644 index 000000000..65a39717f --- /dev/null +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -0,0 +1,78 @@ +#include "flutter/cpp/datasets/mmlu_gen.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature_util.h" + +#include +#include + +namespace mlperf { +namespace mobile { + +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) + : sample_reader_(input_tfrecord), Dataset(backend) { + // Load all TFRecord samples into memory + for (size_t i = 0; i < sample_reader_.Size(); i++) { + tensorflow::tstring record = sample_reader_.ReadRecord(i); + tensorflow::Example example; + example.ParseFromString(record); + std::string input = GetFeatureValues("input", example).Get(0); + std::string answer = GetFeatureValues("answer", example).Get(0); + + auto sample = std::make_unique(); + sample->input = input; + sample->correct_answer = answer; + + samples_.push_back(std::move(sample)); + } +} + +void MmluGen::LoadSamplesToRam(const std::vector& samples) { + for (auto id : samples) { + loaded_sample_ids_.insert(id); + } +} + +void MmluGen::UnloadSamplesFromRam(const std::vector& samples) { + for (auto id : samples) { + loaded_sample_ids_.erase(id); + } +} + +std::vector MmluGen::GetData(int sample_idx) { + std::vector data; + if (sample_idx < samples_.size()) { + data.push_back(reinterpret_cast(const_cast(samples_[sample_idx]->input.c_str()))); + } + return data; +} + +std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { + if (sample_idx >= samples_.size() || outputs.empty()) return {0}; + + const char* prediction = reinterpret_cast(outputs[0]); + char predicted_char = prediction[0]; // Assume first token is the answer (e.g., 'A', 'B', ...) + + const std::string& correct = samples_[sample_idx]->correct_answer; + bool is_correct = (predicted_char == correct[0]); + + total_++; + if (is_correct) correct_++; + + return {static_cast(is_correct)}; +} + +bool MmluGen::HasAccuracy() { + return true; +} + +float MmluGen::ComputeAccuracy() { + return total_ > 0 ? static_cast(correct_) / total_ : 0.0f; +} + +std::string MmluGen::ComputeAccuracyString() { + float acc = ComputeAccuracy(); + return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; +} + +} // namespace mobile +} // namespace mlperf diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h new file mode 100644 index 000000000..4844a2be8 --- /dev/null +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -0,0 +1,61 @@ +#ifndef MLPERF_DATASETS_MMLU_GEN_H_ +#define MLPERF_DATASETS_MMLU_GEN_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "flutter/cpp/dataset.h" +#include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" + +namespace mlperf { +namespace mobile { + +class MmluGen : public Dataset { + public: + MmluGen(Backend* backend, const std::string& input_tfrecord); + + const std::string& Name() override { return name_; } + + size_t TotalSampleCount() override { return samples_.size(); } + + void LoadSamplesToRam(const std::vector& samples) override; + + void UnloadSamplesFromRam(const std::vector& samples) override; + + std::vector GetData(int sample_idx) override; + + std::vector ProcessOutput(const int sample_idx, const std::vector& outputs) override; + + bool HasAccuracy() override; + + float ComputeAccuracy() override; + + std::string ComputeAccuracyString() override; + + private: + const std::string name_ = "MmluGen"; + + TFRecordReader sample_reader_; + + struct PromptSample { + std::string input; + std::string correct_answer; + }; + + std::vector> samples_; + std::set loaded_sample_ids_; + + size_t correct_ = 0; + size_t total_ = 0; +}; + +} // namespace mobile +} // namespace mlperf + +#endif // MLPERF_DATASETS_MMLU_GEN_H_ diff --git a/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py new file mode 100644 index 000000000..a87a389da --- /dev/null +++ b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py @@ -0,0 +1,37 @@ +import tensorflow as tf +import pandas as pd +import argparse + +def parse_args(): + parser = argparse.ArgumentParser(description="Convert a CSV of LLM prompts to TFRecord format.") + parser.add_argument('--input', type=str, required=True, help="Path to the input CSV file.") + parser.add_argument('--output', type=str, required=True, help="Path to the output TFRecord file.") + return parser.parse_args() + +def map_answer(num): + return {1: "A", 2: "B", 3: "C", 4: "D"}.get(num, "X") # Use 'X' as fallback + +def create_example(input_text, answer_letter): + return tf.train.Example(features=tf.train.Features(feature={ + "input": tf.train.Feature(bytes_list=tf.train.BytesList(value=[input_text.encode()])), + "answer": tf.train.Feature(bytes_list=tf.train.BytesList(value=[answer_letter.encode()])), + })) + +def main(): + args = parse_args() + df = pd.read_csv(args.input_csv) + + if "input_formatted" not in df.columns or "answer" not in df.columns: + raise ValueError("CSV must contain 'input_formatted' and 'answer' columns.") + + df["answer_letter"] = df["answer"].map(map_answer) + + with tf.io.TFRecordWriter(args.output_tfrecord) as writer: + for _, row in df.iterrows(): + example = create_example(row["input_formatted"], row["answer_letter"]) + writer.write(example.SerializeToString()) + + print(f"TFRecord written to: {args.output_tfrecord}") + +if __name__ == "__main__": + main() diff --git a/flutter/cpp/flutter/BUILD b/flutter/cpp/flutter/BUILD index eb5ddb103..d6bc26a16 100644 --- a/flutter/cpp/flutter/BUILD +++ b/flutter/cpp/flutter/BUILD @@ -32,6 +32,7 @@ cc_library( "//flutter/cpp/datasets:coco", "//flutter/cpp/datasets:coco_gen", "//flutter/cpp/datasets:imagenet", + "//flutter/cpp/datasets:mmlu_gen", "//flutter/cpp/datasets:snu_sr", "//flutter/cpp/datasets:squad", "//flutter/cpp/proto:mlperf_task_cc_proto", diff --git a/flutter/cpp/proto/mlperf_task.proto b/flutter/cpp/proto/mlperf_task.proto index 63b0a4b84..3eeb843a6 100644 --- a/flutter/cpp/proto/mlperf_task.proto +++ b/flutter/cpp/proto/mlperf_task.proto @@ -80,6 +80,7 @@ message DatasetConfig { ADE20K = 4; SNUSR = 5; COCOGEN = 6; + MMLU = 7; } required DatasetType type = 1; // Config of the dataset. diff --git a/flutter/third_party/BUILD b/flutter/third_party/BUILD new file mode 100644 index 000000000..e9811081e --- /dev/null +++ b/flutter/third_party/BUILD @@ -0,0 +1,247 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +common_linkopts = tflite_linkopts() + select({ + "//conditions:default": [], + "//tensorflow:android": [ + "-pie", + "-llog", + ], +}) + +exports_files(glob([ + "testdata/*.jpg", +])) + +cc_library( + name = "image_preprocessing_stage", + srcs = ["image_preprocessing_stage.cc"], + hdrs = ["image_preprocessing_stage.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite:string", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:types", + "//tensorflow/lite/profiling:time", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:preprocessing_steps_cc_proto", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", + "@libjpeg_turbo//:jpeg", + "@local_xla//xla/tsl/util:stats_calculator_portable", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_jpeg_internal", + ], + "//conditions:default": [ + "//tensorflow/core:jpeg_internal", + ], + }), +) + +cc_test( + name = "image_preprocessing_stage_test", + srcs = ["image_preprocessing_stage_test.cc"], + data = ["testdata/grace_hopper.jpg"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":image_preprocessing_stage", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "topk_accuracy_eval_stage", + srcs = ["topk_accuracy_eval_stage.cc"], + hdrs = ["topk_accuracy_eval_stage.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + ], +) + +cc_test( + name = "topk_accuracy_eval_stage_test", + srcs = ["topk_accuracy_eval_stage_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":topk_accuracy_eval_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tflite_inference_stage", + srcs = ["tflite_inference_stage.cc"], + hdrs = ["tflite_inference_stage.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite:framework", + "//tensorflow/lite/core:framework", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/core/kernels:builtin_ops", + "//tensorflow/lite/profiling:time", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation:utils", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_absl//absl/base:core_headers", + "@local_xla//xla/tsl/util:stats_calculator_portable", + ], +) + +cc_test( + name = "tflite_inference_stage_test", + srcs = ["tflite_inference_stage_test.cc"], + data = ["//tensorflow/lite:testdata/add_quantized.bin"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":tflite_inference_stage", + "//tensorflow/lite:framework", + "//tensorflow/lite/core:framework", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/tools/evaluation:utils", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "image_classification_stage", + srcs = ["image_classification_stage.cc"], + hdrs = ["image_classification_stage.h"], + copts = tflite_copts(), + deps = [ + ":image_preprocessing_stage", + ":tflite_inference_stage", + ":topk_accuracy_eval_stage", + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation:utils", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + ], +) + +cc_library( + name = "inference_profiler_stage", + srcs = ["inference_profiler_stage.cc"], + hdrs = ["inference_profiler_stage.h"], + copts = tflite_copts(), + deps = [ + ":tflite_inference_stage", + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@FP16", + "@local_xla//xla/tsl/util:stats_calculator_portable", + ], +) + +cc_test( + name = "inference_profiler_stage_test", + srcs = ["inference_profiler_stage_test.cc"], + data = ["//tensorflow/lite:testdata/add_quantized.bin"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":inference_profiler_stage", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/delegates/nnapi:nnapi_delegate", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "object_detection_average_precision_stage", + srcs = ["object_detection_average_precision_stage.cc"], + hdrs = ["object_detection_average_precision_stage.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "//tensorflow/lite/tools/evaluation/stages/utils:image_metrics", + ], +) + +cc_test( + name = "object_detection_average_precision_stage_test", + srcs = ["object_detection_average_precision_stage_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":object_detection_average_precision_stage", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "object_detection_stage", + srcs = ["object_detection_stage.cc"], + hdrs = ["object_detection_stage.h"], + copts = tflite_copts(), + deps = [ + ":image_preprocessing_stage", + ":object_detection_average_precision_stage", + ":tflite_inference_stage", + "//tensorflow/core:tflite_portable_logging", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", + "//tensorflow/lite/tools/evaluation:evaluation_stage", + "//tensorflow/lite/tools/evaluation:utils", + "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", + "//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + ], +) diff --git a/flutter/third_party/enable-png-in-tensorflow-lite-tools-evaluation.patch b/flutter/third_party/enable-png-in-tensorflow-lite-tools-evaluation.patch index 92a1f9003..907c979a2 100644 --- a/flutter/third_party/enable-png-in-tensorflow-lite-tools-evaluation.patch +++ b/flutter/third_party/enable-png-in-tensorflow-lite-tools-evaluation.patch @@ -13,7 +13,7 @@ diff --git a/tensorflow/lite/tools/evaluation/stages/BUILD b/tensorflow/lite/too index 9f649588145..e81b284709c 100644 --- a/tensorflow/lite/tools/evaluation/stages/BUILD +++ b/tensorflow/lite/tools/evaluation/stages/BUILD -@@ -53,9 +53,11 @@ cc_library( +@@ -57,9 +57,11 @@ cc_library( ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_jpeg_internal", @@ -29,14 +29,11 @@ diff --git a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.c index a1418c3bcb6..a9750141b3d 100644 --- a/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.cc -@@ -29,6 +29,7 @@ limitations under the License. +@@ -29,3 +29,4 @@ limitations under the License. #include "absl/strings/ascii.h" - #include "tensorflow/core/lib/jpeg/jpeg_handle.h" + #include "jpeglib.h" // from @libjpeg_turbo #include "tensorflow/core/lib/jpeg/jpeg_mem.h" +#include "tensorflow/core/lib/png/png_io.h" - #include "tensorflow/core/platform/logging.h" - #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" - #include "tensorflow/lite/kernels/internal/types.h" @@ -107,6 +108,29 @@ inline void LoadImageJpeg(std::string* filename, ImageData* image_data) { image_data->data.reset(float_image); } diff --git a/mobile_back_tflite/cpp/backend_tflite/BUILD b/mobile_back_tflite/cpp/backend_tflite/BUILD index 651e34eba..41d4cb07a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/BUILD @@ -54,6 +54,7 @@ cc_library( "single_model_pipeline.cc", "stable_diffusion_invoker.cc", "stable_diffusion_pipeline.cc", + "llm_pipeline.cc", "tflite_c.cc", ], hdrs = [ @@ -63,6 +64,7 @@ cc_library( "single_model_pipeline.h", "stable_diffusion_invoker.h", "stable_diffusion_pipeline.h", + "llm_pipeline.h", "tflite_settings_android.h", "tflite_settings_apple.h", "tflite_settings_windows.h", @@ -83,11 +85,17 @@ cc_library( ":tflite_settings", "//flutter/cpp:utils", "//flutter/cpp/c:headers", + "@com_google_sentencepiece//:sentencepiece_processor", "@org_tensorflow//tensorflow/core:tflite_portable_logging", "@org_tensorflow//tensorflow/lite/c:c_api", + "@org_tensorflow//tensorflow/lite/c:c_api_experimental", "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:util", +# "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops", ] + select({ "@org_tensorflow//tensorflow:android": [ + "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:delegate", ], "@org_tensorflow//tensorflow:ios": [ diff --git a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt index a0ca490e8..11eb346bf 100644 --- a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt +++ b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt @@ -257,3 +257,26 @@ benchmark_setting { value: "timestep_embeddings_data.bin.ts" } } + +benchmark_setting { + benchmark_id: "llm" + framework: "TFLite" + delegate_choice: { + delegate_name: "CPU" + accelerator_name: "cpu" + accelerator_desc: "CPU" + model_file: { + model_path: "https://thee.dev/mlc/model.tflite" #Placeholder + model_checksum: "04f62ae20a0f1c68c138f30d88411be0" + } + } + delegate_selected: "CPU" + custom_setting { + id: "pipeline" + value: "LLMPipeline" + } + custom_setting { + id: "sentencepiece_processor_path" + value: "llama.model.sp" + } +} diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc new file mode 100644 index 000000000..f802bf84f --- /dev/null +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -0,0 +1,380 @@ +/* Copyright 2020-2021 The MLPerf Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "llm_pipeline.h" + +#include +#include + +#if defined(MTK_TFLITE_NEURON_BACKEND) && defined(__ANDROID__) +#include + +#include "neuron/APUWareUtilsApi.h" +#endif + +#include "flutter/cpp/c/type.h" +#include "flutter/cpp/utils.h" +#include "tensorflow/lite/c/c_api.h" +#include "tensorflow/lite/c/common.h" +#if __ANDROID__ +#include + +#if MTK_TFLITE_NEURON_BACKEND +#include "neuron/neuron_backend.h" +#include "neuron/neuron_builder.h" +#include "neuron/neuron_delegate.h" +#endif + +#include "tensorflow/lite/delegates/gpu/delegate.h" +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#endif +#include "tensorflow/core/platform/logging.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +static bool backendExists = false; + +#if __ANDROID__ +bool is_emulator() { + char ro_build_characteristics[PROP_VALUE_MAX + 1]; + if (__system_property_get("ro.build.characteristics", + ro_build_characteristics)) { + char *ptr; + ptr = strstr(ro_build_characteristics, "emulator"); + if (ptr) return true; + } + return false; +} +#endif + +// Destroy the backend pointer and its data. +void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + if (backend_data) { + TfLiteModelDelete(backend_data->model); + TfLiteSignatureRunnerDelete(backend_data->prefill_runner); + TfLiteSignatureRunnerDelete(backend_data->decode_runner); + TfLiteInterpreterDelete(backend_data->interpreter); + delete backend_data->sp_processor; + delete backend_data; + } + backendExists = false; +} + +// Create a new backend and return the pointer to it. +// TODO add eos and bos tokens as config parameters +mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_backend_configuration_t *configs, const char *native_lib_path) { + // Verify only one instance of the backend exists at any time + if (backendExists) { + LOG(ERROR) << "Only one backend instance should exist at a time"; + return nullptr; + } + + LLMBackendData *backend_data = new LLMBackendData(); + + // sentencePiece Processor Path + std::string sppp = mlperf::mobile::GetConfigValue(configs, "sentencepiece_processor_path", std::string("")); + + // Load the model. + backend_data->model = TfLiteModelCreateFromFile(model_path); + if (!backend_data->model) { + LOG(ERROR) << "Failed to load model: " << model_path; + backend_delete(backend_data); + return nullptr; + } + + backend_data->interpreter = BuildInterpreter(backend_data->model, backend_data->threads); + if (!backend_data->interpreter) { + LOG(ERROR) << "Failed to load interpreter"; + backend_delete(backend_data); + return nullptr; + } + + backend_data->kv_cache = BuildKVCache(backend_data->interpreter); + //TODO kv_cache check + + backend_data->decode_runner = GetDecodeRunner(backend_data->interpreter, backend_data->kv_cache); + + backend_data->sp_processor = LoadSentencePieceProcessor(sppp); + if (!backend_data->sp_processor) { + LOG(ERROR) << "Failed to load sentencepiece processor: " << sppp; + backend_delete(backend_data); + return nullptr; + } + + return backend_data; +} + +// Vendor name who create this backend. +const char *LLMPipeline::backend_vendor_name(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + return backend_data->vendor; +} + +// TODO: Return the name of the accelerator. +const char *LLMPipeline::backend_accelerator_name(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + return backend_data->accelerator; +} + +// Return the name of this backend. +const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + return backend_data->name; +} + +// Run the inference for a sample. +mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + + // Get Input Tensors for each of the runners. + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* prefill_input = TfLiteSignatureRunnerGetInputTensor(backend_data->prefill_runner, "tokens"); + // Shape: [Seq], Dtype: int32 + TfLiteTensor* prefill_input_pos = TfLiteSignatureRunnerGetInputTensor(backend_data->prefill_runner, "input_pos"); + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* decode_input = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "tokens"); + // Shape: [Seq], Dtype: int32 + TfLiteTensor* decode_input_pos = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "input_pos"); + // shape: [Batch, kv_cache_max, num_query_groups, head_dim] + TfLiteTensor* kv_cache_k_0 = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "kv_cache_k_0"); + + int max_seq_size = prefill_input->dims->data[1]; + int kv_cache_max_size = kv_cache_k_0->dims->data[1]; + int prefill_seq_size = std::min(static_cast(backend_data->prompt_tokens.size()), max_seq_size); + + std::memset(prefill_input->data.i32, 0, prefill_input->bytes); + std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); + for (int i = 0; i < prefill_seq_size - 1; ++i) { + prefill_input->data.i32[i] = backend_data->prompt_tokens[i]; + prefill_input_pos->data.i32[i] = i; + } + + MINIMAL_CHECK(TfLiteSignatureRunnerInvoke(backend_data->prefill_runner) == kTfLiteOk); + + int decode_steps = kv_cache_max_size - prefill_seq_size; + MINIMAL_CHECK(decode_steps > 0); + + std::vector output_tokens; + output_tokens.reserve(decode_steps); + int next_token = backend_data->prompt_tokens[prefill_seq_size - 1]; + int next_position = prefill_seq_size - 1; + for (int i = 0; i < decode_steps; ++i) { + decode_input->data.i32[0] = next_token; + decode_input_pos->data.i32[0] = next_position; + MINIMAL_CHECK(TfLiteSignatureRunnerInvoke(backend_data->decode_runner) == kTfLiteOk); + next_token = GreedySampler(TfLiteSignatureRunnerGetOutputTensor(backend_data->decode_runner, "logits")); + output_tokens.push_back(next_token); + next_position += 1; + if (next_token == backend_data->stop_token_id) break; + } + + MINIMAL_CHECK(backend_data->sp_processor->Decode(output_tokens, &backend_data->output).ok()); + + return MLPERF_SUCCESS; +} + +// Flush the staged queries immediately. +mlperf_status_t LLMPipeline::backend_flush_queries(mlperf_backend_ptr_t backend_ptr) { + return MLPERF_SUCCESS; +} + +// Return the number of inputs of the model. +int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { + return 2; +} + +// Return the type of the ith input. +mlperf_data_t LLMPipeline::backend_get_input_type(mlperf_backend_ptr_t backend_ptr, int32_t i) { + return mlperf_data_t{mlperf_data_t::Int32, 0}; +} + +// Set the data for ith input. +mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, int32_t batch_index, int32_t i, void *data) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + + std::string prompt = std::string(static_cast(data)); + MINIMAL_CHECK(backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens).ok()); //TODO + + if (!backend_data->start_token.empty()) { + backend_data->prompt_tokens.insert(backend_data->prompt_tokens.begin(), backend_data->sp_processor->PieceToId((backend_data->start_token))); + } + + // NOTE block below can be moved safely to backend_create + if (!backend_data->end_token.empty()) { + backend_data->stop_token_id = backend_data->sp_processor->PieceToId((backend_data->end_token)); + } + // --- + + uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; //assuming max tokens is <16k + + backend_data->prefill_runner = GetPrefillRunner(backend_data->interpreter, effective_prefill_token_size, backend_data->kv_cache); + + + return MLPERF_SUCCESS; +} + +// Return the number of outputs for the model. +int32_t LLMPipeline::backend_get_output_count(mlperf_backend_ptr_t backend_ptr) { + return 1; +} + +// Return the type of ith output. +mlperf_data_t LLMPipeline::backend_get_output_type(mlperf_backend_ptr_t backend_ptr, int32_t i) { + return mlperf_data_t{mlperf_data_t::Float32, 0}; +} + +// Get the data from ith output. +mlperf_status_t LLMPipeline::backend_get_output(mlperf_backend_ptr_t backend_ptr, uint32_t batch_index, int32_t i, void **data) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + + if (i == 0) { + *data = backend_data->output.data(); + return MLPERF_SUCCESS; + } + + return MLPERF_FAILURE; +} + +void LLMPipeline::backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, uint8_t *data) {} + +void LLMPipeline::backend_convert_outputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, uint8_t *data) {} + +void *LLMPipeline::backend_get_buffer(size_t n) { + return ::operator new(n); +} + +void LLMPipeline::backend_release_buffer(void *p) { + ::operator delete(p); +} + +TfLiteInterpreter *LLMPipeline::BuildInterpreter(TfLiteModel *model, int num_threads) { + tflite::ops::builtin::BuiltinOpResolver resolver; + // NOTE: We need to manually register optimized OPs for KV-cache and + // Scaled Dot Product Attention (SDPA). + tflite::ops::custom::GenAIOpsRegisterer(&resolver); + tflite::InterpreterBuilder builder(*model, resolver); + //TODO + MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk); + TfLiteInterpreter *interpreter; + builder(&interpreter); + //TODO + MINIMAL_CHECK(interpreter != nullptr); + + return interpreter; +} + +kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { + TfLiteSignatureRunner *runner = interpreter->GetSignatureRunner("decode"); + // TODO + if (runner == nullptr) { + return {}; + } + // The two arguments excluded are `tokens` and `input_pos`. + // TODO more arguments might need to be excluded + size_t num_layers = (TfLiteSignatureRunnerGetInputCount(runner) - 2) / 2; + if (num_layers == 0) { + return {}; + } + + kv_cache_t kv_cache; + for (int i = 0; i < num_layers; ++i) { + std::string k_cache_name = "kv_cache_k_" + std::to_string(i); + std::string v_cache_name = "kv_cache_v_" + std::to_string(i); + // We are assuming K and V tensors are of the same shape. + TfLiteTensor* tensor = TfLiteSignatureRunnerGetInputTensor(runner, k_cache_name.c_str()); + size_t count = tensor->bytes / sizeof(float); + kv_cache.emplace(k_cache_name, + std::vector>(count, 0.0)); + kv_cache.emplace(v_cache_name, + std::vector>(count, 0.0)); + } + + return kv_cache; +} + +void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, kv_cache_t& kv_cache) { + for (auto& [name, cache] : kv_cache) { + TfLiteCustomAllocation allocation = { + .data = static_cast(cache.data()), + .bytes = cache.size() * sizeof(float)}; + // Both input and output tensors are set to the same buffer. Not all + // delegates support this in-place update. For those cases, we need to do + // a ping-pong buffer and update the pointers between inference calls. + //TODO + MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(name.c_str(), allocation) == kTfLiteOk); + //TODO + MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(name.c_str(), allocation) == kTfLiteOk); + } + //TODO + MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk); +} + +TfLiteSignatureRunner *LLMPipeline::GetPrefillRunner(TfLiteInterpreter* interpreter, std::size_t num_input_tokens, kv_cache_t& kv_cache) { + // Find the prefill signature length that best matches the input token size. + TfLiteSignatureRunner* runner = nullptr; + //int best_seq_size = -1; + size_t delta = std::numeric_limits::max(); + for (int32_t i = 0; i < TfLiteInterpreterGetSignatureCount(interpreter); i++) { + std::string key (TfLiteInterpreterGetSignatureKey(interpreter, i)); + if (key.find("prefill") == std::string::npos) continue; + TfLiteTensor* input_pos = TfLiteSignatureRunnerGetInputTensor(TfLiteInterpreterGetSignatureRunner(interpreter, key.c_str()), "input_pos"); + // The expected shape for input position is [Seq]. + size_t seq_size = input_pos->dims->data[0]; + if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) { + runner = TfLiteInterpreterGetSignatureRunner(interpreter, key->c_str()); + //best_seq_size = seq_size; + delta = seq_size - num_input_tokens; + } + } + MINIMAL_CHECK(runner != nullptr); + PrepareRunner(runner->impl, kv_cache); + return runner; +} + +TfLiteSignatureRunner *LLMPipeline::GetDecodeRunner(TfLiteInterpreter* interpreter, kv_cache_t& kv_cache) { + TfLiteSignatureRunner* runner = TfLiteInterpreterGetSignatureRunner(interpreter, "decode"); + MINIMAL_CHECK(runner != nullptr); + PrepareRunner(runner->impl, kv_cache); + return runner; +} + +sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor(std::string path) { + std::ifstream input(path, std::ios::binary); + std::string serialized_proto = std::string( + std::istreambuf_iterator(input), std::istreambuf_iterator()); + auto processor = new sentencepiece::SentencePieceProcessor(); + MINIMAL_CHECK(processor->LoadFromSerializedProto(serialized_proto).ok()); + return processor; +} + +// A basic greedy sampler (equivalent to argmax). +int LLMPipeline::GreedySampler(const TfLiteTensor* logits) { + float max_value = -std::numeric_limits::infinity(); + int max_index = 0; + // logits shape: [Batch, Seq, Vocab], Dtype: float + for (int i = 0; i < logits->dims->data[2]; ++i) { + if (logits->data.f[i] > max_value) { + max_value = logits->data.f[i]; + max_index = i; + } + } + return max_index; +} + +#ifdef __cplusplus +}; +#endif // __cplusplus diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h new file mode 100644 index 000000000..cc24e9808 --- /dev/null +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -0,0 +1,159 @@ +/* Copyright 2024 The MLPerf Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TFLITE_LLM_PIPELINE_H_ +#define TFLITE_LLM_PIPELINE_H_ + +#include +#include +#include + +#include "flutter/cpp/c/type.h" +#include "pipeline.h" +#include "tensorflow/lite/c/c_api.h" +#include "tensorflow/lite/c/c_api_experimental.h" +#include "tensorflow/lite/experimental/genai/genai_ops.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/signature_runner.h" +#include "tensorflow/lite/util.h" +#include "src/sentencepiece_processor.h" + +#include "tensorflow/core/platform/logging.h" + +#define MINIMAL_CHECK(x) \ +if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return MLPERF_FAILURE; \ +} + +// TF Lite requires all buffers (including external buffers used for KV cache +// here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use +// this custom allocator. Please use with caution as different platforms may +// have different alignment requirements. +template +class AlignedAllocator { + public: + using value_type = T; + + T* allocate(std::size_t n) { + void* ptr; + std::size_t size = n * sizeof(T); + std::size_t padding = tflite::kDefaultTensorAlignment - + (size % tflite::kDefaultTensorAlignment); + size += padding; + int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size); + if (ret != 0) { + return nullptr; + } + return static_cast(ptr); + }; + + void deallocate(T* ptr, std::size_t n) { free(ptr); } +}; + +using kv_cache_t = std::map>>; + +struct LLMBackendData { + const char *name = "TFLite"; + const char *vendor = "Google"; + const char *accelerator = "CPU"; + TfLiteModel *model{nullptr}; + sentencepiece::SentencePieceProcessor *sp_processor{nullptr}; + //TfLiteInterpreterOptions *options{}; TODO use this to allow different delegates other than CPU? + TfLiteInterpreter *interpreter{}; + TfLiteSignatureRunner *prefill_runner{nullptr}; + TfLiteSignatureRunner *decode_runner{nullptr}; + kv_cache_t kv_cache; + //std::string input_prompt; + std::vector prompt_tokens; + uint8_t threads = 1; + std::string start_token = ""; + std::string end_token = ""; + int stop_token_id = -1; + std::string output; + +// uint32_t real_batch_size = 1; +//std::unique_ptr executer; +// int32_t original_tensor_size = 0; +//#ifdef MTK_TFLITE_NEURON_BACKEND +// neuron_backend_ptr_t neuronBackendData{nullptr}; +//#endif +}; + +// A simple pipeline which runs a single model. +class LLMPipeline : public Pipeline { + public: + LLMPipeline() = default; + + ~LLMPipeline() override = default; + + void backend_delete(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_backend_ptr_t backend_create(const char *model_path, + mlperf_backend_configuration_t *configs, + const char *native_lib_path) override; + + const char *backend_vendor_name(mlperf_backend_ptr_t backend_ptr) override; + + const char *backend_accelerator_name( + mlperf_backend_ptr_t backend_ptr) override; + + const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_status_t backend_issue_query( + mlperf_backend_ptr_t backend_ptr) override; + + mlperf_status_t backend_flush_queries( + mlperf_backend_ptr_t backend_ptr) override; + + int32_t backend_get_input_count(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_data_t backend_get_input_type(mlperf_backend_ptr_t backend_ptr, + int32_t i) override; + + mlperf_status_t backend_set_input(mlperf_backend_ptr_t backend_ptr, + int32_t batch_index, int32_t i, + void *data) override; + + int32_t backend_get_output_count(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_data_t backend_get_output_type(mlperf_backend_ptr_t backend_ptr, + int32_t i) override; + + mlperf_status_t backend_get_output(mlperf_backend_ptr_t backend_ptr, + uint32_t batchIndex, int32_t i, + void **data) override; + + void backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, int bytes, + int width, int height, uint8_t *data) override; + + void backend_convert_outputs(mlperf_backend_ptr_t backend_ptr, int bytes, + int width, int height, uint8_t *data) override; + + void *backend_get_buffer(size_t n) override; + + void backend_release_buffer(void *p) override; + + private: + TfLiteInterpreter *BuildInterpreter(TfLiteModel *model, int num_threads); + kv_cache_t BuildKVCache(TfLiteInterpreter *interpreter); + void PrepareRunner(tflite::SignatureRunner *runner, kv_cache_t &kv_cache); + TfLiteSignatureRunner *GetPrefillRunner(TfLiteInterpreter *interpreter, std::size_t num_input_tokens, kv_cache_t &kv_cache); + TfLiteSignatureRunner *GetDecodeRunner(TfLiteInterpreter *interpreter, kv_cache_t &kv_cache); + sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor(std::string path); + int GreedySampler(const TfLiteTensor *logits); +}; + +#endif // TFLITE_SINGLE_MODEL_PIPELINE_H_ diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc index ce1eb7a1d..a75985373 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc @@ -654,4 +654,4 @@ void SingleModelPipeline::backend_release_buffer(void *p) { #ifdef __cplusplus }; -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus diff --git a/patches/com_google_sentencepiece.diff b/patches/com_google_sentencepiece.diff new file mode 100644 index 000000000..946231276 --- /dev/null +++ b/patches/com_google_sentencepiece.diff @@ -0,0 +1,2357 @@ +diff --git a/src/bpe_model.cc b/src/bpe_model.cc +index 22cd115..97e0bda 100644 +--- a/src/bpe_model.cc ++++ b/src/bpe_model.cc +@@ -21,7 +21,7 @@ + + #include "bpe_model.h" + #include "freelist.h" +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/bpe_model_trainer.cc b/src/bpe_model_trainer.cc +index 964d44e..64878cd 100644 +--- a/src/bpe_model_trainer.cc ++++ b/src/bpe_model_trainer.cc +@@ -18,7 +18,8 @@ + #include + + #include "bpe_model_trainer.h" +-#include "third_party/absl/container/flat_hash_set.h" ++#include "absl/container/flat_hash_set.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { +@@ -171,7 +172,7 @@ void Trainer::UpdateActiveSymbols() { + active_symbols_.insert(symbols.begin(), symbols.begin() + size); + } + +-util::Status Trainer::Train() { ++absl::Status Trainer::Train() { + RETURN_IF_ERROR(status()); + + CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces()); +diff --git a/src/bpe_model_trainer.h b/src/bpe_model_trainer.h +index e011a37..a17e580 100644 +--- a/src/bpe_model_trainer.h ++++ b/src/bpe_model_trainer.h +@@ -20,7 +20,8 @@ + #include + + #include "sentencepiece_model.pb.h" +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/status/status.h" + #include "trainer_interface.h" + + namespace sentencepiece { +@@ -35,7 +36,7 @@ class Trainer : public TrainerInterface { + : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec, + denormalizer_spec) {} + +- util::Status Train() override; ++ absl::Status Train() override; + + private: + // Symbol represents a character or symbol bigram. +diff --git a/src/bpe_model_trainer_test.cc b/src/bpe_model_trainer_test.cc +index 173eb9c..2a43c3a 100644 +--- a/src/bpe_model_trainer_test.cc ++++ b/src/bpe_model_trainer_test.cc +@@ -20,8 +20,8 @@ + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/builder.cc b/src/builder.cc +index 378aaa0..fd8edf8 100644 +--- a/src/builder.cc ++++ b/src/builder.cc +@@ -18,10 +18,11 @@ + + #include "builder.h" + #include "filesystem.h" +-#include "third_party/absl/strings/str_join.h" +-#include "third_party/absl/strings/str_replace.h" +-#include "third_party/absl/strings/str_split.h" +-#include "third_party/absl/strings/strip.h" ++#include "absl/strings/str_join.h" ++#include "absl/strings/str_replace.h" ++#include "absl/strings/str_split.h" ++#include "absl/strings/strip.h" ++#include "absl/status/status.h" + + #ifdef ENABLE_NFKC_COMPILE + #include +@@ -36,7 +37,7 @@ + + #include "normalization_rule.h" + #include "normalizer.h" +-#include "third_party/darts_clone/darts.h" ++#include "include/darts.h" + #include "util.h" + + namespace sentencepiece { +@@ -145,7 +146,7 @@ Builder::Chars Normalize(const Builder::CharsMap &chars_map, + } // namespace + + // static +-util::Status Builder::CompileCharsMap(const CharsMap &chars_map, ++absl::Status Builder::CompileCharsMap(const CharsMap &chars_map, + std::string *output) { + CHECK_OR_RETURN(output); + CHECK_OR_RETURN(!chars_map.empty()); +@@ -212,7 +213,7 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map, + } + + // static +-util::Status Builder::DecompileCharsMap(absl::string_view blob, ++absl::Status Builder::DecompileCharsMap(absl::string_view blob, + Builder::CharsMap *chars_map) { + CHECK_OR_RETURN(chars_map); + chars_map->clear(); +@@ -265,7 +266,7 @@ util::Status Builder::DecompileCharsMap(absl::string_view blob, + } + + // static +-util::Status Builder::GetPrecompiledCharsMap(const std::string &name, ++absl::Status Builder::GetPrecompiledCharsMap(const std::string &name, + std::string *output) { + CHECK_OR_RETURN(output); + +@@ -282,12 +283,12 @@ util::Status Builder::GetPrecompiledCharsMap(const std::string &name, + return util::OkStatus(); + } + } +- return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) ++ return util::StatusBuilder(absl::StatusCode::kNotFound, GTL_LOC) + << "No precompiled charsmap is found: " << name; + } + + // static +-util::Status Builder::BuildNFKCMap(CharsMap *chars_map) { ++absl::Status Builder::BuildNFKCMap(CharsMap *chars_map) { + #ifdef ENABLE_NFKC_COMPILE + LOG(INFO) << "Running BuildNFKCMap"; + +@@ -345,7 +346,7 @@ util::Status Builder::BuildNFKCMap(CharsMap *chars_map) { + return util::OkStatus(); + } + +-util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { ++absl::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { + #ifdef ENABLE_NFKC_COMPILE + LOG(INFO) << "Running BuildNmtNFKCMap"; + +@@ -420,7 +421,7 @@ util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) { + } + + // static +-util::Status Builder::MergeUnicodeCaseFoldMap(Builder::CharsMap *chars_map) { ++absl::Status Builder::MergeUnicodeCaseFoldMap(Builder::CharsMap *chars_map) { + #ifdef ENABLE_NFKC_COMPILE + for (auto &c : *chars_map) { + std::vector trg; +@@ -445,7 +446,7 @@ util::Status Builder::MergeUnicodeCaseFoldMap(Builder::CharsMap *chars_map) { + } + + // static +-util::Status Builder::BuildNFKC_CFMap(CharsMap *chars_map) { ++absl::Status Builder::BuildNFKC_CFMap(CharsMap *chars_map) { + #ifdef ENABLE_NFKC_COMPILE + CharsMap nfkc_map; + RETURN_IF_ERROR(Builder::BuildNFKCMap(&nfkc_map)); +@@ -460,7 +461,7 @@ util::Status Builder::BuildNFKC_CFMap(CharsMap *chars_map) { + } + + // static +-util::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) { ++absl::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) { + #ifdef ENABLE_NFKC_COMPILE + CharsMap nfkc_map; + RETURN_IF_ERROR(Builder::BuildNmtNFKCMap(&nfkc_map)); +@@ -475,7 +476,7 @@ util::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) { + } + + // static +-util::Status Builder::LoadCharsMap(absl::string_view filename, ++absl::Status Builder::LoadCharsMap(absl::string_view filename, + CharsMap *chars_map) { + LOG(INFO) << "Loading mapping file: " << filename.data(); + CHECK_OR_RETURN(chars_map); +@@ -510,7 +511,7 @@ util::Status Builder::LoadCharsMap(absl::string_view filename, + } + + // static +-util::Status Builder::SaveCharsMap(absl::string_view filename, ++absl::Status Builder::SaveCharsMap(absl::string_view filename, + const Builder::CharsMap &chars_map) { + auto output = filesystem::NewWritableFile(filename); + RETURN_IF_ERROR(output->status()); +@@ -540,7 +541,7 @@ util::Status Builder::SaveCharsMap(absl::string_view filename, + } + + // static +-util::Status Builder::RemoveRedundantMap(CharsMap *chars_map) { ++absl::Status Builder::RemoveRedundantMap(CharsMap *chars_map) { + CHECK_OR_RETURN(chars_map); + + CharsMap new_chars_map; +diff --git a/src/builder.h b/src/builder.h +index 49d2884..8ad872c 100644 +--- a/src/builder.h ++++ b/src/builder.h +@@ -22,7 +22,8 @@ + #include "common.h" + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace normalizer { +@@ -43,15 +44,15 @@ class Builder { + // String-to-string mapping. + using CharsMap = std::map; + +- static util::Status CompileCharsMap(const CharsMap &chars_map, ++ static absl::Status CompileCharsMap(const CharsMap &chars_map, + std::string *output); + + // Decompiles `blob` into `chars_map`. +- static util::Status DecompileCharsMap(absl::string_view blob, ++ static absl::Status DecompileCharsMap(absl::string_view blob, + CharsMap *chars_map); + + // Returns a pre-compiled binary index with `name`. +- static util::Status GetPrecompiledCharsMap(const std::string &name, ++ static absl::Status GetPrecompiledCharsMap(const std::string &name, + std::string *output); + + // Makes a normalization mapping based on NFKC. +@@ -89,30 +90,30 @@ class Builder { + // normalizer is the goal of SentencePiece. + // + // TODO(taku): Make NFC, NFD, and NFKD mapping if necessary. +- static util::Status BuildNFKCMap(CharsMap *chars_map); ++ static absl::Status BuildNFKCMap(CharsMap *chars_map); + + // Makes an NFKC-based mapping with NMT specific modifications around + // whitespaces. +- static util::Status BuildNmtNFKCMap(CharsMap *chars_map); ++ static absl::Status BuildNmtNFKCMap(CharsMap *chars_map); + + // Merge Unicode case folding mapping into `chars_map`. +- static util::Status MergeUnicodeCaseFoldMap(CharsMap *chars_map); ++ static absl::Status MergeUnicodeCaseFoldMap(CharsMap *chars_map); + + // Makes NFKC with Unicode case folding. +- static util::Status BuildNFKC_CFMap(CharsMap *chars_map); ++ static absl::Status BuildNFKC_CFMap(CharsMap *chars_map); + + // Makes NMT NFKC with Unicode case folding. +- static util::Status BuildNmtNFKC_CFMap(CharsMap *chars_map); ++ static absl::Status BuildNmtNFKC_CFMap(CharsMap *chars_map); + + // Builds Chars map save in `filename`. + // Format: + // src_uchar1 src_uchar2 ... trg_uchar1 trg_uchar2... + // (src|trg)_ucharX must be a hex of Unicode code point. +- static util::Status LoadCharsMap(absl::string_view filename, ++ static absl::Status LoadCharsMap(absl::string_view filename, + CharsMap *chars_map); + + // Saves Chars map to `filename` as TSV. +- static util::Status SaveCharsMap(absl::string_view filename, ++ static absl::Status SaveCharsMap(absl::string_view filename, + const CharsMap &chars_map); + + private: +@@ -121,7 +122,7 @@ class Builder { + // Removes redundant rules from `chars_map`. + // When char_maps have "aa" => "bb" and "a" => "b", the first + // rule is not necessary since the second rule can cover the first rule. +- static util::Status RemoveRedundantMap(CharsMap *chars_map); ++ static absl::Status RemoveRedundantMap(CharsMap *chars_map); + }; + } // namespace normalizer + } // namespace sentencepiece +diff --git a/src/builder_test.cc b/src/builder_test.cc +index 4acb7b3..1dee5c7 100644 +--- a/src/builder_test.cc ++++ b/src/builder_test.cc +@@ -18,7 +18,7 @@ + #include "normalizer.h" + #include "sentencepiece_trainer.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/char_model_trainer.cc b/src/char_model_trainer.cc +index f438d78..4f4c603 100644 +--- a/src/char_model_trainer.cc ++++ b/src/char_model_trainer.cc +@@ -16,12 +16,13 @@ + + #include "char_model.h" + #include "char_model_trainer.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { + namespace character { + +-util::Status Trainer::Train() { ++absl::Status Trainer::Train() { + RETURN_IF_ERROR(status()); + + CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces()); +diff --git a/src/char_model_trainer.h b/src/char_model_trainer.h +index e563819..a5d021c 100644 +--- a/src/char_model_trainer.h ++++ b/src/char_model_trainer.h +@@ -17,6 +17,7 @@ + + #include "sentencepiece_model.pb.h" + #include "trainer_interface.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace character { +@@ -30,7 +31,7 @@ class Trainer : public TrainerInterface { + : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec, + denormalizer_spec) {} + +- util::Status Train() override; ++ absl::Status Train() override; + }; + } // namespace character + } // namespace sentencepiece +diff --git a/src/char_model_trainer_test.cc b/src/char_model_trainer_test.cc +index 8c2e4b7..e8b4979 100644 +--- a/src/char_model_trainer_test.cc ++++ b/src/char_model_trainer_test.cc +@@ -19,8 +19,8 @@ + #include "filesystem.h" + #include "sentencepiece_processor.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/common.h b/src/common.h +index 7595634..3a2f4e1 100644 +--- a/src/common.h ++++ b/src/common.h +@@ -46,7 +46,7 @@ typedef int32_t int32; + typedef int64_t int64; + typedef uint8_t uint8; + typedef uint16_t uint16; +-typedef uint32_t char32; ++typedef int32_t char32; + typedef uint32_t uint32; + typedef uint64_t uint64; + +@@ -146,6 +146,7 @@ inline const char *BaseName(const char *path) { + } // namespace logging + } // namespace sentencepiece + ++#ifndef LOG + #define LOG(severity) \ + (::sentencepiece::logging::GetMinLogLevel() > \ + ::sentencepiece::logging::LOG_##severity) \ +@@ -156,6 +157,7 @@ inline const char *BaseName(const char *path) { + std::cerr << ::sentencepiece::logging::BaseName(__FILE__) << "(" \ + << __LINE__ << ") " \ + << "LOG(" << #severity << ") " ++#endif // LOG + + #define CHECK(condition) \ + (condition) ? 0 \ +diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc +index c5a5188..e5db1d7 100644 +--- a/src/compile_charsmap_main.cc ++++ b/src/compile_charsmap_main.cc +@@ -22,8 +22,9 @@ + #include "filesystem.h" + #include "init.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/flags/flag.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + + using sentencepiece::normalizer::Builder; + +@@ -160,7 +161,7 @@ int main(int argc, char **argv) { + + const std::vector>> ++ std::function>> + kRuleList = {{"nfkc", Builder::BuildNFKCMap}, + {"nmt_nfkc", Builder::BuildNmtNFKCMap}, + {"nfkc_cf", Builder::BuildNFKC_CFMap}, +diff --git a/src/error.cc b/src/error.cc +index a226d98..ab4675d 100644 +--- a/src/error.cc ++++ b/src/error.cc +@@ -20,8 +20,8 @@ + #ifdef _USE_EXTERNAL_ABSL + // Naive workaround to define minloglevel on external absl package. + // We want to define them in other cc file. +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/flags/parse.h" ++#include "absl/flags/flag.h" ++#include "absl/flags/parse.h" + ABSL_FLAG(int32, minloglevel, 0, + "Messages logged at a lower level than this don't actually."); + #endif +diff --git a/src/filesystem.cc b/src/filesystem.cc +index 833c8f7..9a1b6c9 100644 +--- a/src/filesystem.cc ++++ b/src/filesystem.cc +@@ -15,7 +15,8 @@ + #include + + #include "filesystem.h" +-#include "third_party/absl/memory/memory.h" ++#include "absl/status/status.h" ++#include "absl/memory/memory.h" + #include "util.h" + + #if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE) +@@ -36,7 +37,7 @@ class PosixReadableFile : public ReadableFile { + is_binary ? std::ios::binary | std::ios::in + : std::ios::in)) { + if (!*is_) +- status_ = util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) ++ status_ = util::StatusBuilder(absl::StatusCode::kNotFound, GTL_LOC) + << "\"" << filename.data() << "\": " << util::StrError(errno); + } + +@@ -44,7 +45,7 @@ class PosixReadableFile : public ReadableFile { + if (is_ != &std::cin) delete is_; + } + +- util::Status status() const { return status_; } ++ absl::Status status() const { return status_; } + + bool ReadLine(std::string *line) { + return static_cast(std::getline(*is_, *line)); +@@ -61,7 +62,7 @@ class PosixReadableFile : public ReadableFile { + } + + private: +- util::Status status_; ++ absl::Status status_; + std::istream *is_; + }; + +@@ -75,7 +76,7 @@ class PosixWritableFile : public WritableFile { + : std::ios::out)) { + if (!*os_) + status_ = +- util::StatusBuilder(util::StatusCode::kPermissionDenied, GTL_LOC) ++ util::StatusBuilder(absl::StatusCode::kPermissionDenied, GTL_LOC) + << "\"" << filename.data() << "\": " << util::StrError(errno); + } + +@@ -83,7 +84,7 @@ class PosixWritableFile : public WritableFile { + if (os_ != &std::cout) delete os_; + } + +- util::Status status() const { return status_; } ++ absl::Status status() const { return status_; } + + bool Write(absl::string_view text) { + os_->write(text.data(), text.size()); +@@ -93,7 +94,7 @@ class PosixWritableFile : public WritableFile { + bool WriteLine(absl::string_view text) { return Write(text) && Write("\n"); } + + private: +- util::Status status_; ++ absl::Status status_; + std::ostream *os_; + }; + +diff --git a/src/filesystem.h b/src/filesystem.h +index e572b4b..6e8e305 100644 +--- a/src/filesystem.h ++++ b/src/filesystem.h +@@ -23,7 +23,8 @@ + + #include "common.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace filesystem { +@@ -33,7 +34,7 @@ class ReadableFile { + explicit ReadableFile(absl::string_view filename, bool is_binary = false) {} + virtual ~ReadableFile() {} + +- virtual util::Status status() const = 0; ++ virtual absl::Status status() const = 0; + virtual bool ReadLine(std::string *line) = 0; + virtual bool ReadAll(std::string *line) = 0; + }; +@@ -44,7 +45,7 @@ class WritableFile { + explicit WritableFile(absl::string_view filename, bool is_binary = false) {} + virtual ~WritableFile() {} + +- virtual util::Status status() const = 0; ++ virtual absl::Status status() const = 0; + virtual bool Write(absl::string_view text) = 0; + virtual bool WriteLine(absl::string_view text) = 0; + }; +diff --git a/src/filesystem_test.cc b/src/filesystem_test.cc +index 790e756..39ece99 100644 +--- a/src/filesystem_test.cc ++++ b/src/filesystem_test.cc +@@ -14,7 +14,7 @@ + + #include "filesystem.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/init.h b/src/init.h +index 090a2d9..acfda8a 100644 +--- a/src/init.h ++++ b/src/init.h +@@ -16,8 +16,8 @@ + #define INIT_H_ + + #include "common.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/flags/parse.h" ++#include "absl/flags/flag.h" ++#include "absl/flags/parse.h" + + ABSL_DECLARE_FLAG(int32, minloglevel); + +diff --git a/src/model_factory.cc b/src/model_factory.cc +index be99501..040c00c 100644 +--- a/src/model_factory.cc ++++ b/src/model_factory.cc +@@ -15,7 +15,7 @@ + #include "bpe_model.h" + #include "char_model.h" + #include "model_factory.h" +-#include "third_party/absl/memory/memory.h" ++#include "absl/memory/memory.h" + #include "unigram_model.h" + #include "word_model.h" + +diff --git a/src/model_interface.cc b/src/model_interface.cc +index c49be1e..22c6378 100644 +--- a/src/model_interface.cc ++++ b/src/model_interface.cc +@@ -16,8 +16,8 @@ + + #include "model_interface.h" + #include "sentencepiece_model.pb.h" +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/str_format.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/str_format.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/model_interface.h b/src/model_interface.h +index aef5b53..c7858fb 100644 +--- a/src/model_interface.h ++++ b/src/model_interface.h +@@ -25,9 +25,10 @@ + #include "normalizer.h" + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/strings/string_view.h" +-#include "third_party/darts_clone/darts.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" ++#include "include/darts.h" + #include "util.h" + + namespace sentencepiece { +@@ -69,7 +70,7 @@ class ModelInterface { + + // Returns Status. + // Encode/Decode functions are valid only when status is OK. +- virtual util::Status status() const { return status_; } ++ virtual absl::Status status() const { return status_; } + + virtual const ModelProto &model_proto() const { return *model_proto_; } + +@@ -82,7 +83,7 @@ class ModelInterface { + // normally users do not need to call this function. This function is provided + // just in case that a user want to manually choose which encoder version to + // use. +- virtual util::Status SetEncoderVersion(EncoderVersion encoder_version) { ++ virtual absl::Status SetEncoderVersion(EncoderVersion encoder_version) { + encoder_version_ = encoder_version; + return util::OkStatus(); + } +@@ -261,7 +262,7 @@ class ModelInterface { + EncoderVersion encoder_version_ = EncoderVersion::kOptimized; + + // status. +- util::Status status_; ++ absl::Status status_; + }; + } // namespace sentencepiece + #endif // MODEL_INTERFACE_H_ +diff --git a/src/model_interface_test.cc b/src/model_interface_test.cc +index 69ee4e6..26a1e05 100644 +--- a/src/model_interface_test.cc ++++ b/src/model_interface_test.cc +@@ -15,7 +15,7 @@ + #include "model_factory.h" + #include "model_interface.h" + #include "testharness.h" +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/normalizer.cc b/src/normalizer.cc +index 100b875..c553906 100644 +--- a/src/normalizer.cc ++++ b/src/normalizer.cc +@@ -18,11 +18,12 @@ + #include + + #include "common.h" +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/match.h" +-#include "third_party/absl/strings/string_view.h" +-#include "third_party/absl/strings/strip.h" +-#include "third_party/darts_clone/darts.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/match.h" ++#include "absl/strings/string_view.h" ++#include "absl/strings/strip.h" ++#include "absl/status/status.h" ++#include "include/darts.h" + #include "util.h" + + namespace sentencepiece { +@@ -71,7 +72,7 @@ void Normalizer::Init() { + } + } + +-util::Status Normalizer::Normalize(absl::string_view input, ++absl::Status Normalizer::Normalize(absl::string_view input, + std::string *normalized, + std::vector *norm_to_orig) const { + norm_to_orig->clear(); +@@ -274,7 +275,7 @@ std::string Normalizer::EncodePrecompiledCharsMap( + } + + // static +-util::Status Normalizer::DecodePrecompiledCharsMap( ++absl::Status Normalizer::DecodePrecompiledCharsMap( + absl::string_view blob, absl::string_view *trie_blob, + absl::string_view *normalized, std::string *buffer) { + uint32 trie_blob_size = 0; +diff --git a/src/normalizer.h b/src/normalizer.h +index 622bbd2..21d1385 100644 +--- a/src/normalizer.h ++++ b/src/normalizer.h +@@ -24,8 +24,9 @@ + #include "common.h" + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/string_view.h" +-#include "third_party/darts_clone/darts.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" ++#include "include/darts.h" + #include "util.h" + + namespace sentencepiece { +@@ -75,7 +76,7 @@ class Normalizer { + + // Returns Status. + // Normalizes function is valid only when status is OK. +- virtual util::Status status() const { return status_; } ++ virtual absl::Status status() const { return status_; } + + // Normalizes a plain utf8 string into an internal representation for + // Sentencepiece model. |norm_to_orig| stores the byte-alignment from +@@ -86,7 +87,7 @@ class Normalizer { + // - Adds a prefix space. + // - Replaces a space with a meta symbol. + // - Removing heading, tailing and other redundant spaces. +- virtual util::Status Normalize(absl::string_view input, ++ virtual absl::Status Normalize(absl::string_view input, + std::string *normalized, + std::vector *norm_to_orig) const; + +@@ -121,7 +122,7 @@ class Normalizer { + absl::string_view normalized); + + // Decodes blob into trie_blob and normalized string. +- static util::Status DecodePrecompiledCharsMap(absl::string_view blob, ++ static absl::Status DecodePrecompiledCharsMap(absl::string_view blob, + absl::string_view *trie_blob, + absl::string_view *normalized, + std::string *buffer = nullptr); +@@ -153,7 +154,7 @@ class Normalizer { + #endif + + // Normalizer's status. +- util::Status status_; ++ absl::Status status_; + }; + } // namespace normalizer + } // namespace sentencepiece +diff --git a/src/pretokenizer_for_training.cc b/src/pretokenizer_for_training.cc +index 049658e..8021511 100644 +--- a/src/pretokenizer_for_training.cc ++++ b/src/pretokenizer_for_training.cc +@@ -14,7 +14,7 @@ + #include + + #include "pretokenizer_for_training.h" +-#include "third_party/absl/strings/str_replace.h" ++#include "absl/strings/str_replace.h" + + namespace sentencepiece { + namespace pretokenizer { +diff --git a/src/pretokenizer_for_training.h b/src/pretokenizer_for_training.h +index 2d3bc82..b4a6de3 100644 +--- a/src/pretokenizer_for_training.h ++++ b/src/pretokenizer_for_training.h +@@ -21,7 +21,8 @@ + #include "common.h" + #include "sentencepiece.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace pretokenizer { +@@ -30,7 +31,7 @@ class PretokenizerForTrainingInterface { + public: + PretokenizerForTrainingInterface() {} + virtual ~PretokenizerForTrainingInterface() {} +- virtual util::Status status() const = 0; ++ virtual absl::Status status() const = 0; + + // Puts kUPPBoundaryStr before and after the pre-tokenizer's segmentation + // when there are no spaces between these tokens. +diff --git a/src/pretokenizer_for_training_test.cc b/src/pretokenizer_for_training_test.cc +index 80f4787..de89fe3 100644 +--- a/src/pretokenizer_for_training_test.cc ++++ b/src/pretokenizer_for_training_test.cc +@@ -13,8 +13,9 @@ + // limitations under the License.! + #include "pretokenizer_for_training.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" + #include "trainer_interface.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace pretokenizer { +@@ -28,7 +29,7 @@ class MockPretokenizer : public PretokenizerForTrainingInterface { + return spt_; + } + +- util::Status status() const override { return util::OkStatus(); } ++ absl::Status status() const override { return util::OkStatus(); } + + void SetOutput(const SentencePieceText &spt) { spt_ = spt; } + +diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc +index 1e4e7a0..78ae527 100644 +--- a/src/sentencepiece_processor.cc ++++ b/src/sentencepiece_processor.cc +@@ -23,14 +23,15 @@ + #include "normalizer.h" + #include "sentencepiece.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/numbers.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" +-#include "third_party/absl/strings/str_replace.h" +-#include "third_party/absl/strings/str_split.h" +-#include "third_party/absl/strings/string_view.h" +-#include "third_party/absl/strings/strip.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/numbers.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" ++#include "absl/strings/str_replace.h" ++#include "absl/strings/str_split.h" ++#include "absl/strings/string_view.h" ++#include "absl/strings/strip.h" ++#include "absl/status/status.h" + #include "unigram_model.h" + #include "util.h" + +@@ -52,7 +53,7 @@ const char kReplacementCharacter[] = "\xef\xbf\xbd"; + SentencePieceProcessor::SentencePieceProcessor() {} + SentencePieceProcessor::~SentencePieceProcessor() {} + +-util::Status SentencePieceProcessor::Load(absl::string_view filename) { ++absl::Status SentencePieceProcessor::Load(absl::string_view filename) { + auto model_proto = absl::make_unique(); + RETURN_IF_ERROR(io::LoadModelProto(filename, model_proto.get())); + return Load(std::move(model_proto)); +@@ -62,13 +63,13 @@ void SentencePieceProcessor::LoadOrDie(absl::string_view filename) { + CHECK_OK(Load(filename)); + } + +-util::Status SentencePieceProcessor::Load(const ModelProto &model_proto) { ++absl::Status SentencePieceProcessor::Load(const ModelProto &model_proto) { + auto model_proto_copy = absl::make_unique(); + *model_proto_copy = model_proto; + return Load(std::move(model_proto_copy)); + } + +-util::Status SentencePieceProcessor::LoadFromSerializedProto( ++absl::Status SentencePieceProcessor::LoadFromSerializedProto( + absl::string_view serialized) { + auto model_proto = absl::make_unique(); + CHECK_OR_RETURN( +@@ -76,7 +77,7 @@ util::Status SentencePieceProcessor::LoadFromSerializedProto( + return Load(std::move(model_proto)); + } + +-util::Status SentencePieceProcessor::Load( ++absl::Status SentencePieceProcessor::Load( + std::unique_ptr model_proto) { + model_proto_ = std::move(model_proto); + model_ = ModelFactory::Create(*model_proto_); +@@ -117,7 +118,7 @@ util::Status SentencePieceProcessor::Load( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SetEncoderVersion( ++absl::Status SentencePieceProcessor::SetEncoderVersion( + EncoderVersion encoder_version) { + return model_->SetEncoderVersion(encoder_version); + } +@@ -126,17 +127,17 @@ EncoderVersion SentencePieceProcessor::GetEncoderVersion() const { + return model_->GetEncoderVersion(); + } + +-util::Status SentencePieceProcessor::SetEncodeExtraOptions( ++absl::Status SentencePieceProcessor::SetEncodeExtraOptions( + absl::string_view extra_options) { + return ParseExtraOptions(extra_options, &encode_extra_options_); + } + +-util::Status SentencePieceProcessor::SetDecodeExtraOptions( ++absl::Status SentencePieceProcessor::SetDecodeExtraOptions( + absl::string_view extra_options) { + return ParseExtraOptions(extra_options, &decode_extra_options_); + } + +-util::Status SentencePieceProcessor::status() const { ++absl::Status SentencePieceProcessor::status() const { + CHECK_OR_RETURN(model_) << "Model is not initialized."; + CHECK_OR_RETURN(normalizer_) << "Normalizer is not initialized."; + RETURN_IF_ERROR(model_->status()); +@@ -144,7 +145,7 @@ util::Status SentencePieceProcessor::status() const { + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SetVocabulary( ++absl::Status SentencePieceProcessor::SetVocabulary( + const std::vector &valid_vocab) { + RETURN_IF_ERROR(status()); + +@@ -174,7 +175,7 @@ util::Status SentencePieceProcessor::SetVocabulary( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::ResetVocabulary() { ++absl::Status SentencePieceProcessor::ResetVocabulary() { + RETURN_IF_ERROR(status()); + for (auto &piece : *(model_proto_->mutable_pieces())) { + if (piece.type() == ModelProto::SentencePiece::UNUSED) +@@ -184,7 +185,7 @@ util::Status SentencePieceProcessor::ResetVocabulary() { + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::LoadVocabulary(absl::string_view filename, ++absl::Status SentencePieceProcessor::LoadVocabulary(absl::string_view filename, + int threshold) { + auto input = filesystem::NewReadableFile(filename); + RETURN_IF_ERROR(input->status()); +@@ -221,7 +222,7 @@ util::Status SentencePieceProcessor::LoadVocabulary(absl::string_view filename, + + ////////////////////////////////////////////////////////////// + // Simple API. +-util::Status SentencePieceProcessor::Encode( ++absl::Status SentencePieceProcessor::Encode( + absl::string_view input, std::vector *pieces) const { + CHECK_OR_RETURN_STATUS_STL(pieces); + +@@ -234,7 +235,7 @@ util::Status SentencePieceProcessor::Encode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::Encode(absl::string_view input, ++absl::Status SentencePieceProcessor::Encode(absl::string_view input, + std::vector *ids) const { + CHECK_OR_RETURN_STATUS_STL(ids); + +@@ -247,7 +248,7 @@ util::Status SentencePieceProcessor::Encode(absl::string_view input, + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::Decode( ++absl::Status SentencePieceProcessor::Decode( + const std::vector &pieces, std::string *detokenized) const { + CHECK_OR_RETURN_STATUS_STL(detokenized); + +@@ -258,7 +259,7 @@ util::Status SentencePieceProcessor::Decode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::Decode(const std::vector &ids, ++absl::Status SentencePieceProcessor::Decode(const std::vector &ids, + std::string *detokenized) const { + CHECK_OR_RETURN_STATUS_STL(detokenized); + +@@ -269,7 +270,7 @@ util::Status SentencePieceProcessor::Decode(const std::vector &ids, + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::NBestEncode( ++absl::Status SentencePieceProcessor::NBestEncode( + absl::string_view input, int nbest_size, + std::vector> *pieces) const { + CHECK_OR_RETURN_STATUS_STL(pieces); +@@ -287,7 +288,7 @@ util::Status SentencePieceProcessor::NBestEncode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::NBestEncode( ++absl::Status SentencePieceProcessor::NBestEncode( + absl::string_view input, int nbest_size, + std::vector> *ids) const { + CHECK_OR_RETURN_STATUS_STL(ids); +@@ -305,7 +306,7 @@ util::Status SentencePieceProcessor::NBestEncode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SampleEncode( ++absl::Status SentencePieceProcessor::SampleEncode( + absl::string_view input, int nbest_size, float alpha, + std::vector *pieces) const { + CHECK_OR_RETURN_STATUS_STL(pieces); +@@ -319,7 +320,7 @@ util::Status SentencePieceProcessor::SampleEncode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SampleEncode(absl::string_view input, ++absl::Status SentencePieceProcessor::SampleEncode(absl::string_view input, + int nbest_size, float alpha, + std::vector *ids) const { + CHECK_OR_RETURN_STATUS_STL(ids); +@@ -333,7 +334,7 @@ util::Status SentencePieceProcessor::SampleEncode(absl::string_view input, + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::PopulateSentencePieceText( ++absl::Status SentencePieceProcessor::PopulateSentencePieceText( + absl::string_view input, absl::string_view normalized, + const std::vector &norm_to_orig, const EncodeResult &result, + SentencePieceText *spt) const { +@@ -424,7 +425,7 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText( + return util::OkStatus(); + } // namespace sentencepiece + +-util::Status SentencePieceProcessor::Encode(absl::string_view input, ++absl::Status SentencePieceProcessor::Encode(absl::string_view input, + SentencePieceText *spt) const { + CHECK_OR_RETURN_STATUS_PROTO(spt); + +@@ -439,7 +440,7 @@ util::Status SentencePieceProcessor::Encode(absl::string_view input, + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::NBestEncode( ++absl::Status SentencePieceProcessor::NBestEncode( + absl::string_view input, int nbest_size, + NBestSentencePieceText *nbest_spt) const { + CHECK_OR_RETURN_STATUS_PROTO(nbest_spt); +@@ -464,7 +465,7 @@ util::Status SentencePieceProcessor::NBestEncode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SampleEncode( ++absl::Status SentencePieceProcessor::SampleEncode( + absl::string_view input, int nbest_size, float alpha, + SentencePieceText *spt) const { + CHECK_OR_RETURN_STATUS_PROTO(spt); +@@ -503,7 +504,7 @@ util::Status SentencePieceProcessor::SampleEncode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::SampleEncodeAndScore( ++absl::Status SentencePieceProcessor::SampleEncodeAndScore( + absl::string_view input, int samples, float theta, bool wor, + bool include_best, NBestSentencePieceText *samples_spt) const { + CHECK_OR_RETURN(model_->IsSampleEncodeAndScoreAvailable()) +@@ -527,7 +528,7 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, ++absl::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, + float theta, + float *entropy) const { + CHECK_OR_RETURN(model_->IsCalculateEntropyAvailable()) +@@ -540,7 +541,7 @@ util::Status SentencePieceProcessor::CalculateEntropy(absl::string_view input, + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::Decode( ++absl::Status SentencePieceProcessor::Decode( + const std::vector &pieces, SentencePieceText *spt) const { + CHECK_OR_RETURN_STATUS_PROTO(spt); + +@@ -591,7 +592,7 @@ util::Status SentencePieceProcessor::Decode( + }; + + auto ProcessBytePieces = [&](int token_index_begin, +- int token_index_end) -> util::Status { ++ int token_index_end) -> absl::Status { + if (token_index_begin >= token_index_end) { + return util::OkStatus(); + } +@@ -661,14 +662,14 @@ util::Status SentencePieceProcessor::Decode( + return util::OkStatus(); + } + +-util::Status SentencePieceProcessor::Decode(const std::vector &ids, ++absl::Status SentencePieceProcessor::Decode(const std::vector &ids, + SentencePieceText *spt) const { + std::vector pieces; + const int num_pieces = GetPieceSize(); + pieces.reserve(ids.size()); + for (const int id : ids) { + if (id < 0 || id >= num_pieces) { +- return util::Status(util::StatusCode::kOutOfRange, ++ return absl::Status(absl::StatusCode::kOutOfRange, + absl::StrCat("Invalid id: ", id)); + } + pieces.emplace_back(IdToPiece(id)); +@@ -783,7 +784,7 @@ int SentencePieceProcessor::pad_id() const { + } + + // static +-util::Status SentencePieceProcessor::ApplyExtraOptions( ++absl::Status SentencePieceProcessor::ApplyExtraOptions( + const std::vector &extra_options, + SentencePieceText *spt) const { + for (const auto &extra_option : extra_options) { +@@ -818,7 +819,7 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( + } + + // static +-util::Status SentencePieceProcessor::ParseExtraOptions( ++absl::Status SentencePieceProcessor::ParseExtraOptions( + absl::string_view _extra_option, + std::vector *extra_options) const { + absl::string_view extra_option(_extra_option.data(), _extra_option.size()); +@@ -877,7 +878,7 @@ void SetRandomGeneratorSeed(unsigned int seed); + + namespace io { + +-util::Status LoadModelProto(absl::string_view filename, ++absl::Status LoadModelProto(absl::string_view filename, + ModelProto *model_proto) { + if (filename.empty()) { + return util::NotFoundError("model file path should not be empty."); +@@ -893,7 +894,7 @@ util::Status LoadModelProto(absl::string_view filename, + return util::OkStatus(); + } + +-util::Status SaveModelProto(absl::string_view filename, ++absl::Status SaveModelProto(absl::string_view filename, + const ModelProto &model_proto) { + if (filename.empty()) { + return util::NotFoundError("model file path should not be empty."); +diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h +index e8bd5f5..346fb0e 100644 +--- a/src/sentencepiece_processor.h ++++ b/src/sentencepiece_processor.h +@@ -20,9 +20,10 @@ + #include + #include + #include ++#include "absl/status/status.h" + + #if defined(_USE_INTERNAL_STRING_VIEW) +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" + #elif defined(_USE_TF_STRING_VIEW) + #include "absl/strings/string_view.h" + #else +@@ -185,7 +186,7 @@ class SentencePieceProcessor { + + // Loads model from `filename`. + // Returns false if `filename` cannot be loaded. +- virtual util::Status Load(absl::string_view filename); ++ virtual absl::Status Load(absl::string_view filename); + + // Loads model from `filename`. + // Crash if `filename` cannot be loaded. +@@ -193,24 +194,24 @@ class SentencePieceProcessor { + + // Loads model from `model_proto`. + // `model_proto` is copied. +- virtual util::Status Load(const ModelProto &model_proto); ++ virtual absl::Status Load(const ModelProto &model_proto); + + // Loads model from `model_proto`. + // `model_proto` is moved. +- virtual util::Status Load(std::unique_ptr model_proto); ++ virtual absl::Status Load(std::unique_ptr model_proto); + + // Loads model from `serialized`, which is a string-serialized model proto. + // Useful to load the model from a platform independent blob object. +- virtual util::Status LoadFromSerializedProto(absl::string_view serialized); ++ virtual absl::Status LoadFromSerializedProto(absl::string_view serialized); + + // Returns the status. Encode/Decode methods are valid when status is OK. +- virtual util::Status status() const; ++ virtual absl::Status status() const; + + // Sets encode extra_option sequence. +- virtual util::Status SetEncodeExtraOptions(absl::string_view extra_option); ++ virtual absl::Status SetEncodeExtraOptions(absl::string_view extra_option); + + // Sets decode extra_option sequence. +- virtual util::Status SetDecodeExtraOptions(absl::string_view extra_option); ++ virtual absl::Status SetDecodeExtraOptions(absl::string_view extra_option); + + ////////////////////////////////////////////////////////////// + // Vocabulary restriction. +@@ -219,41 +220,41 @@ class SentencePieceProcessor { + + // Restricts the vocabulary set. + // The input sentences are encoded into the tokens in `valid_vocab`. +- virtual util::Status SetVocabulary( ++ virtual absl::Status SetVocabulary( + const std::vector &valid_vocab); + + // Reverts the vocabulary restriction. +- virtual util::Status ResetVocabulary(); ++ virtual absl::Status ResetVocabulary(); + + // Loads the valid vocabulary set from `filename` in TSV format. + // Format: . + // Any token with frequency < threshold will be treated as OOV. +- virtual util::Status LoadVocabulary(absl::string_view filename, ++ virtual absl::Status LoadVocabulary(absl::string_view filename, + int threshold); + + ////////////////////////////////////////////////////////////// + // Simple API. + // + // Given a UTF8 input, encodes it into a sequence of sentence pieces. +- virtual util::Status Encode(absl::string_view input, ++ virtual absl::Status Encode(absl::string_view input, + std::vector *pieces) const; + + // Given a UTF8 input, encodes it into a sequence of ids. +- virtual util::Status Encode(absl::string_view input, ++ virtual absl::Status Encode(absl::string_view input, + std::vector *ids) const; + + // Given a sequence of pieces, decodes it into a detokenized output. +- virtual util::Status Decode(const std::vector &pieces, ++ virtual absl::Status Decode(const std::vector &pieces, + std::string *detokenized) const; + + // Given a sequence of ids, decodes it into a detokenized output. +- virtual util::Status Decode(const std::vector &ids, ++ virtual absl::Status Decode(const std::vector &ids, + std::string *detokenized) const; + + // Sets the encoder version. Normally users do not need to call this function. + // But they can call this fucntion just in case if they want to fall back to + // the original encoder. +- virtual util::Status SetEncoderVersion(EncoderVersion encoder_version); ++ virtual absl::Status SetEncoderVersion(EncoderVersion encoder_version); + + // Returns the current encoder version in use. + virtual EncoderVersion GetEncoderVersion() const; +@@ -261,12 +262,12 @@ class SentencePieceProcessor { + ////////////////////////////////////////////////////////////// + // NBest API. + // Same as Encode, but returns nbest results. +- virtual util::Status NBestEncode( ++ virtual absl::Status NBestEncode( + absl::string_view input, int nbest_size, + std::vector> *pieces) const; + + // Same as Encode, but returns nbest results. +- virtual util::Status NBestEncode(absl::string_view input, int nbest_size, ++ virtual absl::Status NBestEncode(absl::string_view input, int nbest_size, + std::vector> *ids) const; + + ////////////////////////////////////////////////////////////// +@@ -289,12 +290,12 @@ class SentencePieceProcessor { + // in https://arxiv.org/abs/1910.13267 + // Nbest-based sampling is not supported so nbest_size parameter is ignored in + // BPE. +- virtual util::Status SampleEncode(absl::string_view input, int nbest_size, ++ virtual absl::Status SampleEncode(absl::string_view input, int nbest_size, + float alpha, + std::vector *pieces) const; + + // Same as above, but returns a sequence of ids. +- virtual util::Status SampleEncode(absl::string_view input, int nbest_size, ++ virtual absl::Status SampleEncode(absl::string_view input, int nbest_size, + float alpha, std::vector *ids) const; + + ////////////////////////////////////////////////////////////// +@@ -303,16 +304,16 @@ class SentencePieceProcessor { + // and internal sentencepiece sequence. + // + // Given a UTF8 input, encodes it into SentencePieceText. +- virtual util::Status Encode(absl::string_view input, ++ virtual absl::Status Encode(absl::string_view input, + SentencePieceText *spt) const; + + // Same as above, but returns NBestSentencePieceText. +- virtual util::Status NBestEncode(absl::string_view input, int nbest_size, ++ virtual absl::Status NBestEncode(absl::string_view input, int nbest_size, + NBestSentencePieceText *nbest_spt) const; + + // Same as above, but samples one segmentation from the hypotheses + // (Lattice). +- virtual util::Status SampleEncode(absl::string_view input, int nbest_size, ++ virtual absl::Status SampleEncode(absl::string_view input, int nbest_size, + float alpha, SentencePieceText *spt) const; + + // Sample `samples` segmentations from the segmentation lattice. +@@ -323,21 +324,21 @@ class SentencePieceProcessor { + // If `include_best` is true, the best tokenization is always included in the + // sample, and the remaining elements are sampled excluding the best. + // This method is only available in Unigram mode. +- virtual util::Status SampleEncodeAndScore( ++ virtual absl::Status SampleEncodeAndScore( + absl::string_view input, int samples, float theta, bool wor, + bool include_best, NBestSentencePieceText *samples_spt) const; + + // Calculate entropy of possible tokenization. + // Only available in unigram mode. +- virtual util::Status CalculateEntropy(absl::string_view input, float theta, ++ virtual absl::Status CalculateEntropy(absl::string_view input, float theta, + float *entropy) const; + + // Given a sequence of pieces, decodes it into SentencePieceText. +- virtual util::Status Decode(const std::vector &pieces, ++ virtual absl::Status Decode(const std::vector &pieces, + SentencePieceText *spt) const; + + // Given a sequence of ids, decodes it into SentencePieceText. +- virtual util::Status Decode(const std::vector &ids, ++ virtual absl::Status Decode(const std::vector &ids, + SentencePieceText *spt) const; + + ////////////////////////////////////////////////////////////// +@@ -487,13 +488,13 @@ class SentencePieceProcessor { + private: + enum ExtraOption { REVERSE, BOS, EOS }; + +- util::Status ParseExtraOptions(absl::string_view extra_option, ++ absl::Status ParseExtraOptions(absl::string_view extra_option, + std::vector *extra_options) const; + +- util::Status ApplyExtraOptions(const std::vector &extra_options, ++ absl::Status ApplyExtraOptions(const std::vector &extra_options, + SentencePieceText *spt) const; + +- util::Status PopulateSentencePieceText( ++ absl::Status PopulateSentencePieceText( + absl::string_view input, absl::string_view normalized, + const std::vector &norm_to_orig, + const std::vector> &result, +@@ -526,10 +527,10 @@ namespace io { + // io::LoadModelProto("//path/spm.model", model_proto.get()); + // SentencePieceProcessor sp; + // CHECK_OK(sp.Load(std::move(model_proto))); +-util::Status LoadModelProto(absl::string_view, ModelProto *model_proto); ++absl::Status LoadModelProto(absl::string_view, ModelProto *model_proto); + + // Saves `model_proto` as `filename`. +-util::Status SaveModelProto(absl::string_view, const ModelProto &model_proto); ++absl::Status SaveModelProto(absl::string_view, const ModelProto &model_proto); + } // namespace io + #endif // SWIG + } // namespace sentencepiece +diff --git a/src/sentencepiece_processor_test.cc b/src/sentencepiece_processor_test.cc +index 373e73e..829c3d4 100644 +--- a/src/sentencepiece_processor_test.cc ++++ b/src/sentencepiece_processor_test.cc +@@ -23,10 +23,10 @@ + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" + #include "testharness.h" +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/string_view.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc +index b9fe64f..5b33cd7 100644 +--- a/src/sentencepiece_trainer.cc ++++ b/src/sentencepiece_trainer.cc +@@ -22,12 +22,13 @@ + #include "sentencepiece_model.pb.h" + #include "sentencepiece_trainer.h" + #include "spec_parser.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/strings/numbers.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_split.h" +-#include "third_party/absl/strings/string_view.h" +-#include "third_party/absl/strings/strip.h" ++#include "absl/flags/flag.h" ++#include "absl/strings/numbers.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_split.h" ++#include "absl/strings/string_view.h" ++#include "absl/strings/strip.h" ++#include "absl/status/status.h" + #include "trainer_factory.h" + #include "util.h" + +@@ -37,7 +38,7 @@ static constexpr char kDefaultNormalizerName[] = "nmt_nfkc"; + } // namespace + + // static +-util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, ++absl::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, + SentenceIterator *sentence_iterator, + std::string *serialized_model_proto) { + NormalizerSpec normalizer_spec; +@@ -45,7 +46,7 @@ util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, + serialized_model_proto); + } + +-util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, ++absl::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, + const NormalizerSpec &normalizer_spec, + SentenceIterator *sentence_iterator, + std::string *serialized_model_proto) { +@@ -55,7 +56,7 @@ util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec, + } + + // static +-util::Status SentencePieceTrainer::Train( ++absl::Status SentencePieceTrainer::Train( + const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, + const NormalizerSpec &denormalizer_spec, + SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { +@@ -97,7 +98,7 @@ NormalizerSpec SentencePieceTrainer::GetNormalizerSpec(absl::string_view name) { + } + + // static +-util::Status SentencePieceTrainer::MergeSpecsFromArgs( ++absl::Status SentencePieceTrainer::MergeSpecsFromArgs( + absl::string_view args, TrainerSpec *trainer_spec, + NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec) { + CHECK_OR_RETURN(trainer_spec) << "`trainer_spec` must not be null."; +@@ -125,7 +126,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( + } + + // static +-util::Status SentencePieceTrainer::MergeSpecsFromArgs( ++absl::Status SentencePieceTrainer::MergeSpecsFromArgs( + const std::unordered_map &kwargs, + TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, + NormalizerSpec *denormalizer_spec) { +@@ -171,7 +172,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( + } + + // static +-util::Status SentencePieceTrainer::Train(absl::string_view args, ++absl::Status SentencePieceTrainer::Train(absl::string_view args, + SentenceIterator *sentence_iterator, + std::string *serialized_model_proto) { + LOG(INFO) << "Running command: " << args.data(); +@@ -185,7 +186,7 @@ util::Status SentencePieceTrainer::Train(absl::string_view args, + } + + // static +-util::Status SentencePieceTrainer::Train( ++absl::Status SentencePieceTrainer::Train( + const std::unordered_map &kwargs, + SentenceIterator *sentence_iterator, std::string *serialized_model_proto) { + TrainerSpec trainer_spec; +@@ -198,7 +199,7 @@ util::Status SentencePieceTrainer::Train( + } + + // static +-util::Status SentencePieceTrainer::PopulateNormalizerSpec( ++absl::Status SentencePieceTrainer::PopulateNormalizerSpec( + NormalizerSpec *normalizer_spec, bool is_denormalizer) { + CHECK_OR_RETURN(normalizer_spec); + +@@ -226,7 +227,7 @@ util::Status SentencePieceTrainer::PopulateNormalizerSpec( + } + + // static +-util::Status SentencePieceTrainer::PopulateModelTypeFromString( ++absl::Status SentencePieceTrainer::PopulateModelTypeFromString( + absl::string_view type, TrainerSpec *spec) { + static const std::unordered_map + kModelTypeMap = {{"unigram", TrainerSpec::UNIGRAM}, +@@ -239,7 +240,7 @@ util::Status SentencePieceTrainer::PopulateModelTypeFromString( + return util::OkStatus(); + } + +- return util::StatusBuilder(util::StatusCode::kInternal, GTL_LOC) ++ return util::StatusBuilder(absl::StatusCode::kInternal, GTL_LOC) + << "\"" << type << "\" is not found in TrainerSpec"; + } + +@@ -248,7 +249,7 @@ const pretokenizer::PretokenizerForTrainingInterface *g_pretokenizer = nullptr; + } // namespace + + // static +-util::Status SentencePieceTrainer::SetPretokenizerForTraining( ++absl::Status SentencePieceTrainer::SetPretokenizerForTraining( + const pretokenizer::PretokenizerForTrainingInterface *pretokenizer) { + g_pretokenizer = pretokenizer; + return util::OkStatus(); +diff --git a/src/sentencepiece_trainer.h b/src/sentencepiece_trainer.h +index bb74ab9..ec6cf93 100644 +--- a/src/sentencepiece_trainer.h ++++ b/src/sentencepiece_trainer.h +@@ -19,6 +19,7 @@ + #include + + #include "sentencepiece_processor.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + +@@ -46,7 +47,7 @@ class SentenceIterator { + virtual bool done() const = 0; + virtual void Next() = 0; + virtual const std::string &value() const = 0; +- virtual util::Status status() const = 0; ++ virtual absl::Status status() const = 0; + }; + + class SentencePieceTrainer { +@@ -54,14 +55,14 @@ class SentencePieceTrainer { + // Trains SentencePiece model with `trainer_spec`. + // Default `normalizer_spec` is used. + // When `sentence_iterator` is passed, load sentences from the iterator. +- static util::Status Train(const TrainerSpec &trainer_spec, ++ static absl::Status Train(const TrainerSpec &trainer_spec, + SentenceIterator *sentence_iterator = nullptr, + std::string *serialized_model_proto = nullptr); + + // Trains SentencePiece model with `trainer_spec` and + // `normalizer_spec`. + // When `sentence_iterator` is passed, load sentences from the iterator. +- static util::Status Train(const TrainerSpec &trainer_spec, ++ static absl::Status Train(const TrainerSpec &trainer_spec, + const NormalizerSpec &normalizer_spec, + SentenceIterator *sentence_iterator = nullptr, + std::string *serialized_model_proto = nullptr); +@@ -69,7 +70,7 @@ class SentencePieceTrainer { + // Trains SentencePiece model with `trainer_spec`, `normalizer_spec` + // and `denormalizer_spec`. + // When `sentence_iterator` is passed, load sentences from the iterator. +- static util::Status Train(const TrainerSpec &trainer_spec, ++ static absl::Status Train(const TrainerSpec &trainer_spec, + const NormalizerSpec &normalizer_spec, + const NormalizerSpec &denormalizer_spec, + SentenceIterator *sentence_iterator = nullptr, +@@ -78,13 +79,13 @@ class SentencePieceTrainer { + // e.g., + // '--input=data --model_prefix=m --vocab_size=8192 model_type=unigram' + // When `sentence_iterator` is passed, load sentences from the iterator. +- static util::Status Train(absl::string_view args, ++ static absl::Status Train(absl::string_view args, + SentenceIterator *sentence_iterator = nullptr, + std::string *serialized_model_proto = nullptr); + + // Trains SentencePiece model with mapin `kwargs`. + // e.g., {{"input", "data"}, {"model_prefix, "m"}, {"vocab_size", "8192"}...} +- static util::Status Train( ++ static absl::Status Train( + const std::unordered_map &kwargs, + SentenceIterator *sentence_iterator = nullptr, + std::string *serialized_model_proto = nullptr); +@@ -96,19 +97,19 @@ class SentencePieceTrainer { + + // Populates necessary fields (precompiled_charmap) from + // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`. +- static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec, ++ static absl::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec, + bool is_denormalizer = false); + + // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the + // std::unordered_map in `kargs`. +- static util::Status MergeSpecsFromArgs( ++ static absl::Status MergeSpecsFromArgs( + const std::unordered_map &kwargs, + TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, + NormalizerSpec *denormalizer_spec); + + // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the + // command line flags in `args`. +- static util::Status MergeSpecsFromArgs(absl::string_view args, ++ static absl::Status MergeSpecsFromArgs(absl::string_view args, + TrainerSpec *trainer_spec, + NormalizerSpec *normalizer_spec, + NormalizerSpec *denormalizer_spec); +@@ -116,7 +117,7 @@ class SentencePieceTrainer { + // Injects global pre-tokenizer that are applied in training time. + // Pretokenizer is only used for extracting pieces. + // TODO(taku): It would be better to inject per `trainer_spec`. +- static util::Status SetPretokenizerForTraining( ++ static absl::Status SetPretokenizerForTraining( + const pretokenizer::PretokenizerForTrainingInterface *pretokenizer); + + // Returns the current pretokenizer. if no pretokenizer is defined, returns +@@ -129,17 +130,17 @@ class SentencePieceTrainer { + // with comma-separated values. `field_name` must not be a nested message. + // The body of these functions are automatically generated with + // data/gen_spec_parser.pl +- static util::Status SetProtoField(const std::string &name, ++ static absl::Status SetProtoField(const std::string &name, + const std::string &value, + TrainerSpec *message); + +- static util::Status SetProtoField(const std::string &name, ++ static absl::Status SetProtoField(const std::string &name, + const std::string &value, + NormalizerSpec *message); + + // Populates model type from string representation, e.g., "bpe". + // Supported model: "unigram", "bpe", "word", "char". +- static util::Status PopulateModelTypeFromString(absl::string_view type, ++ static absl::Status PopulateModelTypeFromString(absl::string_view type, + TrainerSpec *trainer_spec); + + private: +diff --git a/src/sentencepiece_trainer_test.cc b/src/sentencepiece_trainer_test.cc +index e44e66b..00c8d08 100644 +--- a/src/sentencepiece_trainer_test.cc ++++ b/src/sentencepiece_trainer_test.cc +@@ -16,7 +16,8 @@ + #include "sentencepiece_model.pb.h" + #include "sentencepiece_trainer.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { +@@ -109,7 +110,7 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) { + bool done() const override { return idx_ == vec_.size(); } + void Next() override { ++idx_; } + const std::string &value() const override { return vec_[idx_]; } +- util::Status status() const override { return util::OkStatus(); } ++ absl::Status status() const override { return util::OkStatus(); } + + private: + std::vector vec_; +diff --git a/src/spec_parser.h b/src/spec_parser.h +index 2c5a95b..259c45d 100644 +--- a/src/spec_parser.h ++++ b/src/spec_parser.h +@@ -19,8 +19,9 @@ + #include + + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/ascii.h" +-#include "third_party/absl/strings/str_split.h" ++#include "absl/strings/ascii.h" ++#include "absl/strings/str_split.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { +@@ -49,7 +50,7 @@ namespace sentencepiece { + if (name == #param_name) { \ + int32 v; \ + if (!string_util::lexical_cast(value, &v)) \ +- return util::StatusBuilder(util::StatusCode::kInvalidArgument, GTL_LOC) \ ++ return util::StatusBuilder(absl::StatusCode::kInvalidArgument, GTL_LOC) \ + << "cannot parse \"" << value << "\" as int."; \ + message->set_##param_name(v); \ + return util::OkStatus(); \ +@@ -59,7 +60,7 @@ namespace sentencepiece { + if (name == #param_name) { \ + uint64 v; \ + if (!string_util::lexical_cast(value, &v)) \ +- return util::StatusBuilder(util::StatusCode::kInvalidArgument, GTL_LOC) \ ++ return util::StatusBuilder(absl::StatusCode::kInvalidArgument, GTL_LOC) \ + << "cannot parse \"" << value << "\" as int."; \ + message->set_##param_name(v); \ + return util::OkStatus(); \ +@@ -69,7 +70,7 @@ namespace sentencepiece { + if (name == #param_name) { \ + double v; \ + if (!string_util::lexical_cast(value, &v)) \ +- return util::StatusBuilder(util::StatusCode::kInvalidArgument, GTL_LOC) \ ++ return util::StatusBuilder(absl::StatusCode::kInvalidArgument, GTL_LOC) \ + << "cannot parse \"" << value << "\" as int."; \ + message->set_##param_name(v); \ + return util::OkStatus(); \ +@@ -79,7 +80,7 @@ namespace sentencepiece { + if (name == #param_name) { \ + bool v; \ + if (!string_util::lexical_cast(value.empty() ? "true" : value, &v)) \ +- return util::StatusBuilder(util::StatusCode::kInvalidArgument, GTL_LOC) \ ++ return util::StatusBuilder(absl::StatusCode::kInvalidArgument, GTL_LOC) \ + << "cannot parse \"" << value << "\" as bool."; \ + message->set_##param_name(v); \ + return util::OkStatus(); \ +@@ -89,7 +90,7 @@ namespace sentencepiece { + if (name == #param_name) { \ + const auto it = map_name.find(absl::AsciiStrToUpper(value)); \ + if (it == map_name.end()) \ +- return util::StatusBuilder(util::StatusCode::kInvalidArgument, GTL_LOC) \ ++ return util::StatusBuilder(absl::StatusCode::kInvalidArgument, GTL_LOC) \ + << "unknown enumeration value of \"" << value << "\" as " \ + << #map_name; \ + message->set_##param_name(it->second); \ +@@ -186,7 +187,7 @@ inline std::string PrintProto(const NormalizerSpec &message, + return os.str(); + } + +-util::Status SentencePieceTrainer::SetProtoField(const std::string &name, ++absl::Status SentencePieceTrainer::SetProtoField(const std::string &name, + const std::string &value, + TrainerSpec *message) { + CHECK_OR_RETURN(message); +@@ -239,11 +240,11 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name, + PARSE_STRING(pad_piece); + PARSE_STRING(unk_surface); + +- return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) ++ return util::StatusBuilder(absl::StatusCode::kNotFound, GTL_LOC) + << "unknown field name \"" << name << "\" in TrainerSpec."; + } + +-util::Status SentencePieceTrainer::SetProtoField(const std::string &name, ++absl::Status SentencePieceTrainer::SetProtoField(const std::string &name, + const std::string &value, + NormalizerSpec *message) { + CHECK_OR_RETURN(message); +@@ -255,7 +256,7 @@ util::Status SentencePieceTrainer::SetProtoField(const std::string &name, + PARSE_BOOL(escape_whitespaces); + PARSE_STRING(normalization_rule_tsv); + +- return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC) ++ return util::StatusBuilder(absl::StatusCode::kNotFound, GTL_LOC) + << "unknown field name \"" << name << "\" in NormalizerSpec."; + } + +diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc +index 3382ddc..9dda65c 100644 +--- a/src/spm_decode_main.cc ++++ b/src/spm_decode_main.cc +@@ -21,8 +21,8 @@ + #include "init.h" + #include "sentencepiece.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/strings/str_split.h" ++#include "absl/flags/flag.h" ++#include "absl/strings/str_split.h" + #include "util.h" + + ABSL_FLAG(std::string, model, "", "model file name"); +diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc +index 4d12a38..29b7458 100644 +--- a/src/spm_encode_main.cc ++++ b/src/spm_encode_main.cc +@@ -21,10 +21,10 @@ + #include "init.h" + #include "sentencepiece.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/flags/flag.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "trainer_interface.h" + + ABSL_FLAG(std::string, model, "", "model file name"); +diff --git a/src/spm_export_vocab_main.cc b/src/spm_export_vocab_main.cc +index b5d93cb..70a65c1 100644 +--- a/src/spm_export_vocab_main.cc ++++ b/src/spm_export_vocab_main.cc +@@ -20,7 +20,7 @@ + #include "init.h" + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/flags/flag.h" ++#include "absl/flags/flag.h" + + ABSL_FLAG(std::string, output, "", "Output filename"); + ABSL_FLAG(std::string, model, "", "input model file name"); +diff --git a/src/spm_normalize_main.cc b/src/spm_normalize_main.cc +index 96da360..8c541b8 100644 +--- a/src/spm_normalize_main.cc ++++ b/src/spm_normalize_main.cc +@@ -21,7 +21,7 @@ + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" +-#include "third_party/absl/flags/flag.h" ++#include "absl/flags/flag.h" + + ABSL_FLAG(std::string, model, "", "Model file name"); + ABSL_FLAG(bool, use_internal_normalization, false, +diff --git a/src/spm_train_main.cc b/src/spm_train_main.cc +index baf8dbf..ba1e811 100644 +--- a/src/spm_train_main.cc ++++ b/src/spm_train_main.cc +@@ -18,10 +18,10 @@ + #include "init.h" + #include "sentencepiece_model.pb.h" + #include "sentencepiece_trainer.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/strings/ascii.h" +-#include "third_party/absl/strings/str_join.h" +-#include "third_party/absl/strings/str_split.h" ++#include "absl/flags/flag.h" ++#include "absl/strings/ascii.h" ++#include "absl/strings/str_join.h" ++#include "absl/strings/str_split.h" + #include "util.h" + + using sentencepiece::NormalizerSpec; +diff --git a/src/testharness.cc b/src/testharness.cc +index f6b1efe..daf2d14 100644 +--- a/src/testharness.cc ++++ b/src/testharness.cc +@@ -26,7 +26,7 @@ + #include + + #include "common.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/testharness.h b/src/testharness.h +index 9879b06..98317ad 100644 +--- a/src/testharness.h ++++ b/src/testharness.h +@@ -21,9 +21,9 @@ + #include + + #include "common.h" +-#include "third_party/absl/flags/flag.h" +-#include "third_party/absl/flags/parse.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/flags/flag.h" ++#include "absl/flags/parse.h" ++#include "absl/strings/string_view.h" + + ABSL_DECLARE_FLAG(std::string, test_tmpdir); + ABSL_DECLARE_FLAG(std::string, test_srcdir); +diff --git a/src/trainer_factory.cc b/src/trainer_factory.cc +index d1d2541..ff594d0 100644 +--- a/src/trainer_factory.cc ++++ b/src/trainer_factory.cc +@@ -14,7 +14,7 @@ + + #include "bpe_model_trainer.h" + #include "char_model_trainer.h" +-#include "third_party/absl/memory/memory.h" ++#include "absl/memory/memory.h" + #include "trainer_factory.h" + #include "unigram_model_trainer.h" + #include "word_model_trainer.h" +diff --git a/src/trainer_interface.cc b/src/trainer_interface.cc +index a3a4b74..3e441ec 100644 +--- a/src/trainer_interface.cc ++++ b/src/trainer_interface.cc +@@ -26,13 +26,14 @@ + #include "normalizer.h" + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/numbers.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_format.h" +-#include "third_party/absl/strings/str_join.h" +-#include "third_party/absl/strings/str_split.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/numbers.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_format.h" ++#include "absl/strings/str_join.h" ++#include "absl/strings/str_split.h" ++#include "absl/status/status.h" + #include "trainer_interface.h" + #include "unicode_script.h" + #include "util.h" +@@ -49,7 +50,7 @@ const char32 TrainerInterface::kUPPBoundaryChar = L'\u0009'; + const char TrainerInterface::kUPPBoundaryStr[] = "\t"; + + namespace { +-util::Status VerifySpec(const TrainerSpec &trainer_spec) { ++absl::Status VerifySpec(const TrainerSpec &trainer_spec) { + CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0); + + if (trainer_spec.model_type() == TrainerSpec::UNIGRAM || +@@ -164,7 +165,7 @@ bool MultiFileSentenceIterator::done() const { + return (!read_done_ && file_index_ == files_.size()); + } + +-util::Status MultiFileSentenceIterator::status() const { ++absl::Status MultiFileSentenceIterator::status() const { + CHECK_OR_RETURN(fp_); + return fp_->status(); + } +@@ -212,7 +213,7 @@ bool TrainerInterface::IsValidSentencePiece( + } + + constexpr unicode_script::ScriptType kAnyType = +- static_cast(-1); ++ static_cast(0); + + unicode_script::ScriptType prev_script = kAnyType; + bool all_whitespace_piece = +@@ -296,7 +297,7 @@ bool TrainerInterface::IsValidSentencePiece( + return true; + } + +-util::Status TrainerInterface::LoadSentences() { ++absl::Status TrainerInterface::LoadSentences() { + RETURN_IF_ERROR(status()); + CHECK_OR_RETURN(sentences_.empty()); + CHECK_OR_RETURN(required_chars_.empty()); +@@ -537,7 +538,7 @@ void TrainerInterface::SplitSentencesByWhitespace() { + LOG(INFO) << "Done! " << sentences_.size(); + } + +-util::Status TrainerInterface::Serialize(ModelProto *model_proto) const { ++absl::Status TrainerInterface::Serialize(ModelProto *model_proto) const { + RETURN_IF_ERROR(status()); + + // Duplicated sentencepiece is not allowed. +@@ -611,7 +612,7 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const { + return util::OkStatus(); + } + +-util::Status TrainerInterface::SaveModel(absl::string_view filename) const { ++absl::Status TrainerInterface::SaveModel(absl::string_view filename) const { + LOG(INFO) << "Saving model: " << filename; + ModelProto model_proto; + RETURN_IF_ERROR(Serialize(&model_proto)); +@@ -622,7 +623,7 @@ util::Status TrainerInterface::SaveModel(absl::string_view filename) const { + return util::OkStatus(); + } + +-util::Status TrainerInterface::SaveVocab(absl::string_view filename) const { ++absl::Status TrainerInterface::SaveVocab(absl::string_view filename) const { + LOG(INFO) << "Saving vocabs: " << filename; + ModelProto model_proto; + RETURN_IF_ERROR(Serialize(&model_proto)); +@@ -644,7 +645,7 @@ util::Status TrainerInterface::SaveVocab(absl::string_view filename) const { + return util::OkStatus(); + } + +-util::Status TrainerInterface::Save() const { ++absl::Status TrainerInterface::Save() const { + if (output_model_proto_) { + RETURN_IF_ERROR(Serialize(output_model_proto_)); + } else { +@@ -654,7 +655,7 @@ util::Status TrainerInterface::Save() const { + return util::OkStatus(); + } + +-util::Status TrainerInterface::InitMetaPieces() { ++absl::Status TrainerInterface::InitMetaPieces() { + CHECK_OR_RETURN(meta_pieces_.empty()); + bool has_unk = false; + +diff --git a/src/trainer_interface.h b/src/trainer_interface.h +index f66d59a..b4fbc7b 100644 +--- a/src/trainer_interface.h ++++ b/src/trainer_interface.h +@@ -27,7 +27,8 @@ + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { +@@ -57,7 +58,7 @@ class MultiFileSentenceIterator : public SentenceIterator { + bool done() const override; + void Next() override; + const std::string &value() const override { return value_; } +- util::Status status() const override; ++ absl::Status status() const override; + + private: + void TryRead(); +@@ -90,16 +91,16 @@ class TrainerInterface { + + // Loads sentence from `sentence_iterator` and stores the model + // to `output_model_proto`. +- virtual util::Status Train(SentenceIterator *sentence_iterator, ++ virtual absl::Status Train(SentenceIterator *sentence_iterator, + ModelProto *output_model_proto) { + sentence_iterator_ = sentence_iterator; + output_model_proto_ = output_model_proto; + return Train(); + } + +- virtual util::Status Train() { return status(); } ++ virtual absl::Status Train() { return status(); } + +- virtual util::Status status() const { return status_; } ++ virtual absl::Status status() const { return status_; } + + FRIEND_TEST(TrainerInterfaceTest, IsValidSentencePieceTest); + FRIEND_TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest); +@@ -115,7 +116,7 @@ class TrainerInterface { + + // Loads all sentences from spec.input() or SentenceIterator. + // It loads at most input_sentence_size sentences. +- util::Status LoadSentences(); ++ absl::Status LoadSentences(); + + // Splits all sentencecs by whitespaces and + // replace the |sentences_| with tokenized string. +@@ -125,7 +126,7 @@ class TrainerInterface { + void SplitSentencesByWhitespace(); + + // Save model files into spec.model_prefix(). +- util::Status Save() const; ++ absl::Status Save() const; + + // Set of characters which must be included in the final vocab. + // The value of this map stores the frequency. +@@ -152,7 +153,7 @@ class TrainerInterface { + meta_pieces_; + + // Detect errors on initialization. +- util::Status status_; ++ absl::Status status_; + + // Loads sentences from SentenceIterator if not null. + SentenceIterator *sentence_iterator_ = nullptr; +@@ -162,19 +163,19 @@ class TrainerInterface { + + private: + // Serialize final_pieces_ to |model_proto|. +- util::Status Serialize(ModelProto *model_proto) const; ++ absl::Status Serialize(ModelProto *model_proto) const; + + // Saves the best sentence split with the current model for debugging. +- util::Status SaveSplits(absl::string_view filename) const; ++ absl::Status SaveSplits(absl::string_view filename) const; + + // Saves model file. +- util::Status SaveModel(absl::string_view filename) const; ++ absl::Status SaveModel(absl::string_view filename) const; + + // Saves vocabulary file for NMT. +- util::Status SaveVocab(absl::string_view filename) const; ++ absl::Status SaveVocab(absl::string_view filename) const; + + // Initializes `meta_pieces_` from TrainerSpec. +- util::Status InitMetaPieces(); ++ absl::Status InitMetaPieces(); + + // Randomly sampled raw sentences for self-testing. + std::vector self_test_samples_; +diff --git a/src/trainer_interface_test.cc b/src/trainer_interface_test.cc +index 70a51ad..d7f3f0c 100644 +--- a/src/trainer_interface_test.cc ++++ b/src/trainer_interface_test.cc +@@ -16,8 +16,8 @@ + + #include "filesystem.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_format.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_format.h" + #include "trainer_interface.h" + #include "util.h" + +diff --git a/src/unicode_script.cc b/src/unicode_script.cc +index 583dc30..11b24dc 100644 +--- a/src/unicode_script.cc ++++ b/src/unicode_script.cc +@@ -14,7 +14,7 @@ + + #include + +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" + #include "unicode_script.h" + #include "unicode_script_map.h" + #include "util.h" +diff --git a/src/unicode_script_map.h b/src/unicode_script_map.h +index f2e67e9..f1b8299 100644 +--- a/src/unicode_script_map.h ++++ b/src/unicode_script_map.h +@@ -14,7 +14,7 @@ + + #ifndef UNICODE_SCRIPT_DATA_H_ + #define UNICODE_SCRIPT_DATA_H_ +-#include "third_party/absl/container/flat_hash_map.h" ++#include "absl/container/flat_hash_map.h" + namespace sentencepiece { + namespace unicode_script { + namespace { +diff --git a/src/unicode_script_test.cc b/src/unicode_script_test.cc +index ab33565..e0b1c4d 100644 +--- a/src/unicode_script_test.cc ++++ b/src/unicode_script_test.cc +@@ -14,7 +14,7 @@ + + #include "common.h" + #include "testharness.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" + #include "unicode_script.h" + #include "util.h" + +diff --git a/src/unigram_model.cc b/src/unigram_model.cc +index 3b99060..9c72fb9 100644 +--- a/src/unigram_model.cc ++++ b/src/unigram_model.cc +@@ -22,9 +22,9 @@ + #include + #include + +-#include "third_party/absl/memory/memory.h" +-#include "third_party/absl/strings/str_split.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/memory/memory.h" ++#include "absl/strings/str_split.h" ++#include "absl/strings/string_view.h" + #include "unigram_model.h" + #include "util.h" + +diff --git a/src/unigram_model.h b/src/unigram_model.h +index 448e489..9062f12 100644 +--- a/src/unigram_model.h ++++ b/src/unigram_model.h +@@ -24,7 +24,7 @@ + #include "freelist.h" + #include "model_interface.h" + #include "sentencepiece_model.pb.h" +-#include "third_party/darts_clone/darts.h" ++#include "include/darts.h" + + namespace sentencepiece { + namespace unigram { +diff --git a/src/unigram_model_test.cc b/src/unigram_model_test.cc +index f93b21c..808e907 100644 +--- a/src/unigram_model_test.cc ++++ b/src/unigram_model_test.cc +@@ -22,8 +22,8 @@ + #include "sentencepiece_model.pb.h" + #include "sentencepiece_processor.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "util.h" + + namespace sentencepiece { +diff --git a/src/unigram_model_trainer.cc b/src/unigram_model_trainer.cc +index 9615040..7d16bd2 100644 +--- a/src/unigram_model_trainer.cc ++++ b/src/unigram_model_trainer.cc +@@ -25,8 +25,9 @@ + #include "normalizer.h" + #include "pretokenizer_for_training.h" + #include "sentencepiece_trainer.h" +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/memory/memory.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/memory/memory.h" ++#include "absl/status/status.h" + #include "third_party/esaxx/esa.hxx" // Suffix array library. + #include "unicode_script.h" + #include "unigram_model_trainer.h" +@@ -463,7 +464,7 @@ TrainerModel::SentencePieces Trainer::FinalizeSentencePieces( + return Sorted(final_sentencepieces); + } + +-util::Status Trainer::Train() { ++absl::Status Trainer::Train() { + RETURN_IF_ERROR(status()); + + CHECK_EQ_OR_RETURN(TrainerSpec::UNIGRAM, trainer_spec_.model_type()); +diff --git a/src/unigram_model_trainer.h b/src/unigram_model_trainer.h +index 91fbeb4..d41967d 100644 +--- a/src/unigram_model_trainer.h ++++ b/src/unigram_model_trainer.h +@@ -21,7 +21,8 @@ + #include + + #include "sentencepiece_model.pb.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + #include "trainer_interface.h" + #include "unigram_model.h" + #include "util.h" +@@ -68,7 +69,7 @@ class Trainer : public TrainerInterface { + : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec, + denormalizer_spec) {} + +- util::Status Train() override; ++ absl::Status Train() override; + + private: + FRIEND_TEST(TrainerTest, IsValidSentencePieceTest); +diff --git a/src/unigram_model_trainer_test.cc b/src/unigram_model_trainer_test.cc +index ffe515e..fdb25f6 100644 +--- a/src/unigram_model_trainer_test.cc ++++ b/src/unigram_model_trainer_test.cc +@@ -16,8 +16,8 @@ + #include "sentencepiece_processor.h" + #include "sentencepiece_trainer.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "unigram_model_trainer.h" + #include "util.h" + +diff --git a/src/util.h b/src/util.h +index 0d15863..7122c7c 100644 +--- a/src/util.h ++++ b/src/util.h +@@ -30,7 +30,8 @@ + + #include "common.h" + #include "sentencepiece_processor.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + + #ifdef SPM_NO_THREADLOCAL + #include +@@ -359,14 +360,14 @@ std::string StrError(int errnum); + + std::vector StrSplitAsCSV(absl::string_view text); + +-inline Status OkStatus() { return Status(); } ++inline absl::Status OkStatus() { return absl::Status(); } + + #define DECLARE_ERROR(FUNC) \ +- inline util::Status FUNC##Error(absl::string_view str) { \ +- return util::Status(StatusCode::k##FUNC, str.data()); \ ++ inline absl::Status FUNC##Error(absl::string_view str) { \ ++ return absl::Status(absl::StatusCode::k##FUNC, str.data()); \ + } \ +- inline bool Is##FUNC(const util::Status &status) { \ +- return status.code() == StatusCode::k##FUNC; \ ++ inline bool Is##FUNC(const absl::Status &status) { \ ++ return status.code() ==absl::StatusCode::k##FUNC; \ + } + + DECLARE_ERROR(Cancelled) +@@ -390,8 +391,8 @@ DECLARE_ERROR(Unauthenticated) + + class StatusBuilder { + public: +- explicit StatusBuilder(StatusCode code) : code_(code) {} +- explicit StatusBuilder(StatusCode code, int loc) : code_(code) {} ++ explicit StatusBuilder(absl::StatusCode code) : code_(code) {} ++ explicit StatusBuilder(absl::StatusCode code, int loc) : code_(code) {} + + template + StatusBuilder &operator<<(const T &value) { +@@ -399,10 +400,10 @@ class StatusBuilder { + return *this; + } + +- operator Status() const { return Status(code_, os_.str()); } ++ operator absl::Status() const { return absl::Status(code_, os_.str()); } + + private: +- StatusCode code_; ++ absl::StatusCode code_; + std::ostringstream os_; + }; + +@@ -410,7 +411,7 @@ class StatusBuilder { + if (condition) { \ + } else /* NOLINT */ \ + return ::sentencepiece::util::StatusBuilder( \ +- ::sentencepiece::util::StatusCode::kInternal) \ ++ ::absl::StatusCode::kInternal) \ + << __FILE__ << "(" << __LINE__ << ") [" << #condition << "] " + + #define CHECK_EQ_OR_RETURN(a, b) CHECK_OR_RETURN((a) == (b)) +diff --git a/src/util_test.cc b/src/util_test.cc +index 71d006f..67290dc 100644 +--- a/src/util_test.cc ++++ b/src/util_test.cc +@@ -16,7 +16,8 @@ + + #include "filesystem.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" ++#include "absl/strings/str_cat.h" ++#include "absl/status/status.h" + #include "util.h" + + namespace sentencepiece { +@@ -376,27 +377,27 @@ TEST(UtilTest, STLDeleteELementsTest) { + } + + TEST(UtilTest, StatusTest) { +- const util::Status ok; ++ const absl::Status ok; + EXPECT_TRUE(ok.ok()); +- EXPECT_EQ(util::StatusCode::kOk, ok.code()); ++ EXPECT_EQ(absl::StatusCode::kOk, ok.code()); + EXPECT_EQ(std::string(""), ok.message()); + +- const util::Status s1(util::StatusCode::kUnknown, "unknown"); +- const util::Status s2(util::StatusCode::kUnknown, std::string("unknown")); ++ const absl::Status s1(absl::StatusCode::kUnknown, "unknown"); ++ const absl::Status s2(absl::StatusCode::kUnknown, std::string("unknown")); + +- EXPECT_EQ(util::StatusCode::kUnknown, s1.code()); +- EXPECT_EQ(util::StatusCode::kUnknown, s2.code()); ++ EXPECT_EQ(absl::StatusCode::kUnknown, s1.code()); ++ EXPECT_EQ(absl::StatusCode::kUnknown, s2.code()); + EXPECT_EQ(std::string("unknown"), s1.message()); + EXPECT_EQ(std::string("unknown"), s2.message()); + + auto ok2 = util::OkStatus(); + EXPECT_TRUE(ok2.ok()); +- EXPECT_EQ(util::StatusCode::kOk, ok2.code()); ++ EXPECT_EQ(absl::StatusCode::kOk, ok2.code()); + EXPECT_EQ(std::string(""), ok2.message()); + + util::OkStatus().IgnoreError(); + for (int i = 1; i <= 16; ++i) { +- util::Status s(static_cast(i), "message"); ++ absl::Status s(static_cast(i), "message"); + EXPECT_TRUE(s.ToString().find("message") != std::string::npos) + << s.ToString(); + } +diff --git a/src/word_model_trainer.cc b/src/word_model_trainer.cc +index 0b8b062..bc1f86b 100644 +--- a/src/word_model_trainer.cc ++++ b/src/word_model_trainer.cc +@@ -15,8 +15,9 @@ + #include + #include + +-#include "third_party/absl/container/flat_hash_map.h" +-#include "third_party/absl/strings/string_view.h" ++#include "absl/container/flat_hash_map.h" ++#include "absl/strings/string_view.h" ++#include "absl/status/status.h" + #include "util.h" + #include "word_model.h" + #include "word_model_trainer.h" +@@ -24,7 +25,7 @@ + namespace sentencepiece { + namespace word { + +-util::Status Trainer::Train() { ++absl::Status Trainer::Train() { + RETURN_IF_ERROR(status()); + + CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces()); +diff --git a/src/word_model_trainer.h b/src/word_model_trainer.h +index 76f8f32..436e595 100644 +--- a/src/word_model_trainer.h ++++ b/src/word_model_trainer.h +@@ -17,6 +17,7 @@ + + #include "sentencepiece_model.pb.h" + #include "trainer_interface.h" ++#include "absl/status/status.h" + + namespace sentencepiece { + namespace word { +@@ -34,7 +35,7 @@ class Trainer : public TrainerInterface { + : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec, + denormalizer_spec) {} + +- util::Status Train() override; ++ absl::Status Train() override; + }; + } // namespace word + } // namespace sentencepiece +diff --git a/src/word_model_trainer_test.cc b/src/word_model_trainer_test.cc +index c4a8bc6..366810f 100644 +--- a/src/word_model_trainer_test.cc ++++ b/src/word_model_trainer_test.cc +@@ -18,8 +18,8 @@ + #include "filesystem.h" + #include "sentencepiece_processor.h" + #include "testharness.h" +-#include "third_party/absl/strings/str_cat.h" +-#include "third_party/absl/strings/str_join.h" ++#include "absl/strings/str_cat.h" ++#include "absl/strings/str_join.h" + #include "util.h" + #include "word_model_trainer.h" + diff --git a/patches/darts_clone.BUILD b/patches/darts_clone.BUILD new file mode 100644 index 000000000..3ce02f045 --- /dev/null +++ b/patches/darts_clone.BUILD @@ -0,0 +1,12 @@ +licenses(["notice"]) + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "darts_clone", + hdrs = [ + "include/darts.h", + ], +) diff --git a/patches/darts_no_exceptions.diff b/patches/darts_no_exceptions.diff new file mode 100644 index 000000000..d1aadd24c --- /dev/null +++ b/patches/darts_no_exceptions.diff @@ -0,0 +1,87 @@ +--- include/darts.h ++++ include/darts.h +@@ -14,10 +14,14 @@ + // what() as well as that of . + #define DARTS_INT_TO_STR(value) #value + #define DARTS_LINE_TO_STR(line) DARTS_INT_TO_STR(line) + #define DARTS_LINE_STR DARTS_LINE_TO_STR(__LINE__) +-#define DARTS_THROW(msg) throw Darts::Details::Exception( \ +- __FILE__ ":" DARTS_LINE_STR ": exception: " msg) ++#include ++#include ++#define DARTS_THROW(msg) do { \ ++ std::fprintf(stderr, "Darts error: %s (%s:%d)\n", msg, __FILE__, __LINE__); \ ++ std::abort(); \ ++} while (0) + + namespace Darts { + + // The following namespace hides the internal types and classes. +@@ -85,17 +89,9 @@ + // constant or static string because an keeps only a pointer to + // that string. + class Exception : public std::exception { + public: +- explicit Exception(const char *msg = NULL) throw() : msg_(msg) {} +- Exception(const Exception &rhs) throw() : msg_(rhs.msg_) {} +- virtual ~Exception() throw() {} + +- // overrides what() of . +- virtual const char *what() const throw() { +- return (msg_ != NULL) ? msg_ : ""; +- } +- + private: + const char *msg_; + + // Disallows operator=. +@@ -375,16 +371,11 @@ + } + } + + unit_type *buf; +- try { +- buf = new unit_type[size]; +- for (id_type i = 0; i < 256; ++i) { +- buf[i] = units[i]; +- } +- } catch (const std::bad_alloc &) { +- std::fclose(file); +- DARTS_THROW("failed to open double-array: std::bad_alloc"); ++ buf = new unit_type[size]; ++ for (id_type i = 0; i < 256; ++i) { ++ buf[i] = units[i]; + } + + if (size > 256) { + if (std::fread(buf + 256, unit_size(), size - 256, file) != size - 256) { +@@ -701,13 +692,9 @@ + } + } + + AutoArray buf; +- try { +- buf.reset(new char[sizeof(T) * capacity]); +- } catch (const std::bad_alloc &) { +- DARTS_THROW("failed to resize pool: std::bad_alloc"); +- } ++ buf.reset(new char[sizeof(T) * capacity]); + + if (size_ > 0) { + T *src = reinterpret_cast(&buf_[0]); + T *dest = reinterpret_cast(&buf[0]); +@@ -840,13 +827,9 @@ + } + }; + + inline void BitVector::build() { +- try { +- ranks_.reset(new id_type[units_.size()]); +- } catch (const std::bad_alloc &) { +- DARTS_THROW("failed to build rank index: std::bad_alloc"); +- } ++ ranks_.reset(new id_type[units_.size()]); + + num_ones_ = 0; + for (std::size_t i = 0; i < units_.size(); ++i) { + ranks_[i] = num_ones_; diff --git a/patches/ndk_25_r14.diff b/patches/ndk_25_r14.diff deleted file mode 100644 index f3043aa92..000000000 --- a/patches/ndk_25_r14.diff +++ /dev/null @@ -1,184 +0,0 @@ -diff --git a/configure.py b/configure.py -index 262637734a5..b6eb015463b 100644 ---- a/configure.py -+++ b/configure.py -@@ -17,6 +17,7 @@ - import argparse - import errno - import glob -+import json - import os - import platform - import re -@@ -36,7 +37,7 @@ _DEFAULT_TENSORRT_VERSION = '6' - _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' - - _SUPPORTED_ANDROID_NDK_VERSIONS = [ -- 19, 20, 21 -+ 19, 20, 21, 25 - ] - - _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 -@@ -744,20 +745,16 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % - (android_ndk_home_path, ndk_version, _SUPPORTED_ANDROID_NDK_VERSIONS)) -+ write_action_env_to_bazelrc('ANDROID_NDK_VERSION', ndk_version) - - # Now grab the NDK API level to use. Note that this is different from the - # SDK API level, as the NDK API level is effectively the *min* target SDK - # version. -- platforms = os.path.join(android_ndk_home_path, 'platforms') -- api_levels = sorted(os.listdir(platforms)) -- api_levels = [ -- x.replace('android-', '') for x in api_levels if 'android-' in x -- ] -- -- def valid_api_level(api_level): -- return os.path.exists( -- os.path.join(android_ndk_home_path, 'platforms', 'android-' + api_level) -- ) -+ meta = open(os.path.join(android_ndk_home_path, 'meta/platforms.json')) -+ platforms = json.load(meta) -+ meta.close -+ aliases = platforms['aliases'] -+ api_levels = sorted(list(set([ aliases[i] for i in aliases ]))) - - android_ndk_api_level = prompt_loop_or_load_from_env( - environ_cp, -@@ -768,7 +765,7 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): - '[Available levels: %s]' - ) - % api_levels, -- check_success=valid_api_level, -+ check_success=(lambda *_: True), - error_msg='Android-%s is not present in the NDK path.', - ) - -diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl -index 7e9faa558a4..3ea52dc3948 100644 ---- a/tensorflow/workspace2.bzl -+++ b/tensorflow/workspace2.bzl -@@ -816,6 +816,13 @@ def _tf_repositories(): - urls = tf_mirror_urls("https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"), - ) - -+ tf_http_archive( -+ name = "rules_android_ndk", -+ sha256 = "b29409496439cdcdb50a8e161c4953ca78a548e16d3ee729a1b5cd719ffdacbf", -+ strip_prefix = "rules_android_ndk-81ec8b79dc50ee97e336a25724fdbb28e33b8d41", -+ urls = tf_mirror_urls("https://github.com/bazelbuild/rules_android_ndk/archive/81ec8b79dc50ee97e336a25724fdbb28e33b8d41.zip"), -+ ) -+ - # Apple and Swift rules. - # https://github.com/bazelbuild/rules_apple/releases - tf_http_archive( -diff --git a/third_party/android/android.bzl.tpl b/third_party/android/android.bzl.tpl -index e6ed4994f3b..802873f9cb4 100644 ---- a/third_party/android/android.bzl.tpl -+++ b/third_party/android/android.bzl.tpl -@@ -1,3 +1,5 @@ -+MAYBE_ANDROID_NDK_STARLARK_RULES -+ - """Set up configurable Android SDK and NDK dependencies.""" - - def android_workspace(): -diff --git a/third_party/android/android_configure.bzl b/third_party/android/android_configure.bzl -index 2b364118073..bd1a1933172 100644 ---- a/third_party/android/android_configure.bzl -+++ b/third_party/android/android_configure.bzl -@@ -14,8 +14,9 @@ - - _ANDROID_NDK_HOME = "ANDROID_NDK_HOME" - _ANDROID_SDK_HOME = "ANDROID_SDK_HOME" --_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL" --_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL" -+_ANDROID_NDK_VERSION = "ANDROID_NDK_VERSION" -+_ANDROID_NDK_API_LEVEL = "ANDROID_NDK_API_LEVEL" -+_ANDROID_SDK_API_LEVEL = "ANDROID_SDK_API_LEVEL" - _ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION" - - _ANDROID_SDK_REPO_TEMPLATE = """ -@@ -27,7 +28,7 @@ _ANDROID_SDK_REPO_TEMPLATE = """ - ) - """ - --_ANDROID_NDK_REPO_TEMPLATE = """ -+_ANDROID_NDK_REPO_TEMPLATE_INTERNAL = """ - native.android_ndk_repository( - name="androidndk", - path="%s", -@@ -35,15 +36,36 @@ _ANDROID_NDK_REPO_TEMPLATE = """ - ) - """ - -+_ANDROID_NDK_REPO_TEMPLATE_STARLARK = """ -+ android_ndk_repository( -+ name="androidndk", -+ path="%s", -+ api_level=%s, -+ ) -+ -+ # Bind android/crosstool to support legacy select() -+ # https://github.com/bazelbuild/rules_android_ndk/issues/31#issuecomment-1396182185 -+ native.bind( -+ name = "android/crosstool", -+ actual = "@androidndk//:toolchain", -+ ) -+""" -+ -+# Import NDK Starlark rules. Shouldn't have any indentation. -+_ANDROID_NDK_STARLARK_RULES = """ -+load("@rules_android_ndk//:rules.bzl", "android_ndk_repository") -+""" -+ - def _android_autoconf_impl(repository_ctx): - """Implementation of the android_autoconf repository rule.""" - sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME) -- sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION) -+ sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_LEVEL) - build_tools_version = repository_ctx.os.environ.get( - _ANDROID_BUILD_TOOLS_VERSION, - ) - ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME) -- ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION) -+ ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_LEVEL) -+ ndk_version = int(repository_ctx.os.environ.get(_ANDROID_NDK_VERSION)) - - sdk_rule = "" - if all([sdk_home, sdk_api_level, build_tools_version]): -@@ -54,8 +76,13 @@ def _android_autoconf_impl(repository_ctx): - ) - - ndk_rule = "" -+ ndk_starlark_rules = "" - if all([ndk_home, ndk_api_level]): -- ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level) -+ if ndk_version >= 25: -+ ndk_starlark_rules = _ANDROID_NDK_STARLARK_RULES -+ ndk_rule = _ANDROID_NDK_REPO_TEMPLATE_STARLARK % (ndk_home, ndk_api_level) -+ else: -+ ndk_rule = _ANDROID_NDK_REPO_TEMPLATE_INTERNAL % (ndk_home, ndk_api_level) - - if ndk_rule == "" and sdk_rule == "": - sdk_rule = "pass" -@@ -68,6 +95,7 @@ def _android_autoconf_impl(repository_ctx): - "android.bzl", - Label("//third_party/android:android.bzl.tpl"), - substitutions = { -+ "MAYBE_ANDROID_NDK_STARLARK_RULES": ndk_starlark_rules, - "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule, - "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule, - }, -@@ -76,8 +104,9 @@ def _android_autoconf_impl(repository_ctx): - android_configure = repository_rule( - implementation = _android_autoconf_impl, - environ = [ -- _ANDROID_SDK_API_VERSION, -- _ANDROID_NDK_API_VERSION, -+ _ANDROID_SDK_API_LEVEL, -+ _ANDROID_NDK_VERSION, -+ _ANDROID_NDK_API_LEVEL, - _ANDROID_BUILD_TOOLS_VERSION, - _ANDROID_NDK_HOME, - _ANDROID_SDK_HOME, diff --git a/patches/sentencepiece.BUILD b/patches/sentencepiece.BUILD new file mode 100644 index 000000000..8e46b1376 --- /dev/null +++ b/patches/sentencepiece.BUILD @@ -0,0 +1,165 @@ +package( + default_visibility = ["//visibility:public"], + features = [ + "layering_check", + "parse_headers", + ], +) + +licenses(["notice"]) + +proto_library( + name = "sentencepiece_proto", + srcs = ["src/sentencepiece.proto"], +) + +cc_proto_library( + name = "sentencepiece_cc_proto", + deps = [":sentencepiece_proto"], +) + +proto_library( + name = "sentencepiece_model_proto", + srcs = ["src/sentencepiece_model.proto"], +) + +cc_proto_library( + name = "sentencepiece_model_cc_proto", + deps = [":sentencepiece_model_proto"], +) + +genrule( + name = "config_h", + srcs = ["config.h.in"], + outs = ["config.h"], + cmd = "cp $< $@", +) + +cc_library( + name = "common", + hdrs = [ + "config.h", + "src/common.h", + ], + deps = [ + "@com_google_absl//absl/base", + ], +) + +cc_library( + name = "sentencepiece_processor", + srcs = [ + "src/bpe_model.cc", + "src/char_model.cc", + "src/error.cc", + "src/filesystem.cc", + "src/model_factory.cc", + "src/model_interface.cc", + "src/normalizer.cc", + "src/sentencepiece_processor.cc", + "src/unigram_model.cc", + "src/util.cc", + "src/word_model.cc", + ], + hdrs = [ + "src/bpe_model.h", + "src/char_model.h", + "src/filesystem.h", + "src/freelist.h", + "src/model_factory.h", + "src/model_interface.h", + "src/normalizer.h", + "src/sentencepiece_processor.h", + "src/trainer_interface.h", + "src/unigram_model.h", + "src/util.h", + "src/word_model.h", + ], + defines = ["_USE_TF_STRING_VIEW"], + includes = [ + ".", + "src", + ], + linkstatic = 1, + deps = + [ + ":common", + ":sentencepiece_cc_proto", + ":sentencepiece_model_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@darts_clone", + ], +) + +cc_library( + name = "sentencepiece_trainer", + srcs = [ + "src/bpe_model_trainer.cc", + "src/builder.cc", + "src/char_model_trainer.cc", + "src/sentencepiece_trainer.cc", + "src/trainer_factory.cc", + "src/trainer_interface.cc", + "src/unicode_script.cc", + "src/unigram_model_trainer.cc", + "src/word_model_trainer.cc", + ], + hdrs = [ + "src/bpe_model_trainer.h", + "src/builder.h", + "src/char_model_trainer.h", + "src/normalization_rule.h", + "src/sentencepiece_trainer.h", + "src/spec_parser.h", + "src/trainer_factory.h", + "src/trainer_interface.h", + "src/unicode_script.h", + "src/unicode_script_map.h", + "src/unigram_model_trainer.h", + "src/word_model_trainer.h", + "third_party/esaxx/esa.hxx", + "third_party/esaxx/sais.hxx", + ], + includes = [ + ".", + "src", + "third_party/esaxx", + ], + deps = [ + ":common", + ":pretokenizer_for_training", + ":sentencepiece_cc_proto", + ":sentencepiece_model_cc_proto", + ":sentencepiece_processor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@darts_clone", + ], +) + +cc_library( + name = "pretokenizer_for_training", + srcs = ["src/pretokenizer_for_training.cc"], + hdrs = ["src/pretokenizer_for_training.h"], + includes = [ + ".", + "src", + ], + deps = [ + ":common", + ":sentencepiece_cc_proto", + ":sentencepiece_processor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) From 5aab20a6cc2c20c8b415b4ac19daa8efa20f897a Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 5 Aug 2025 02:37:28 +0300 Subject: [PATCH 02/74] fixed issues preventing libraries from compiling, runtime errors not included --- mobile_back_tflite/cpp/backend_tflite/BUILD | 3 +- .../cpp/backend_tflite/llm_pipeline.cc | 98 +++++++++++-------- .../cpp/backend_tflite/llm_pipeline.h | 17 +++- 3 files changed, 73 insertions(+), 45 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/BUILD b/mobile_back_tflite/cpp/backend_tflite/BUILD index 41d4cb07a..c93e92471 100644 --- a/mobile_back_tflite/cpp/backend_tflite/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/BUILD @@ -92,7 +92,8 @@ cc_library( "@org_tensorflow//tensorflow/lite/c:common", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:util", -# "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops", ] + select({ "@org_tensorflow//tensorflow:android": [ "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index f802bf84f..06f739206 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #if defined(MTK_TFLITE_NEURON_BACKEND) && defined(__ANDROID__) #include @@ -25,17 +26,10 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "flutter/cpp/utils.h" -#include "tensorflow/lite/c/c_api.h" #include "tensorflow/lite/c/common.h" #if __ANDROID__ #include -#if MTK_TFLITE_NEURON_BACKEND -#include "neuron/neuron_backend.h" -#include "neuron/neuron_builder.h" -#include "neuron/neuron_delegate.h" -#endif - #include "tensorflow/lite/delegates/gpu/delegate.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #endif @@ -47,19 +41,6 @@ extern "C" { static bool backendExists = false; -#if __ANDROID__ -bool is_emulator() { - char ro_build_characteristics[PROP_VALUE_MAX + 1]; - if (__system_property_get("ro.build.characteristics", - ro_build_characteristics)) { - char *ptr; - ptr = strstr(ro_build_characteristics, "emulator"); - if (ptr) return true; - } - return false; -} -#endif - // Destroy the backend pointer and its data. void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; @@ -262,23 +243,20 @@ void LLMPipeline::backend_release_buffer(void *p) { } TfLiteInterpreter *LLMPipeline::BuildInterpreter(TfLiteModel *model, int num_threads) { - tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(options, num_threads); // NOTE: We need to manually register optimized OPs for KV-cache and // Scaled Dot Product Attention (SDPA). - tflite::ops::custom::GenAIOpsRegisterer(&resolver); - tflite::InterpreterBuilder builder(*model, resolver); - //TODO - MINIMAL_CHECK(builder.SetNumThreads(num_threads) == kTfLiteOk); - TfLiteInterpreter *interpreter; - builder(&interpreter); - //TODO - MINIMAL_CHECK(interpreter != nullptr); + TfLiteInterpreterOptionsAddCustomOp(options, "GEN_AI_GENERATE", GetGenAIGenerateOp(), 1, 1); + TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + + MINIMAL_CHECK_PTR(interpreter != nullptr); return interpreter; } kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { - TfLiteSignatureRunner *runner = interpreter->GetSignatureRunner("decode"); + TfLiteSignatureRunner *runner = TfLiteInterpreterGetSignatureRunner(interpreter, "decode"); // TODO if (runner == nullptr) { return {}; @@ -306,7 +284,7 @@ kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { return kv_cache; } -void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, kv_cache_t& kv_cache) { +void LLMPipeline::PrepareRunner(TfLiteSignatureRunner* runner, kv_cache_t& kv_cache) { for (auto& [name, cache] : kv_cache) { TfLiteCustomAllocation allocation = { .data = static_cast(cache.data()), @@ -314,13 +292,10 @@ void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, kv_cache_t& kv_ // Both input and output tensors are set to the same buffer. Not all // delegates support this in-place update. For those cases, we need to do // a ping-pong buffer and update the pointers between inference calls. - //TODO - MINIMAL_CHECK(runner->SetCustomAllocationForInputTensor(name.c_str(), allocation) == kTfLiteOk); - //TODO - MINIMAL_CHECK(runner->SetCustomAllocationForOutputTensor(name.c_str(), allocation) == kTfLiteOk); + MINIMAL_CHECK_VOID(TfLiteSignatureRunnerSetInputCustomAllocation(runner, name.c_str(), &allocation)); + MINIMAL_CHECK_VOID(TfLiteSignatureRunnerSetInputCustomAllocation(runner, name.c_str(), &allocation)); } - //TODO - MINIMAL_CHECK(runner->AllocateTensors() == kTfLiteOk); + MINIMAL_CHECK_VOID(TfLiteSignatureRunnerAllocateTensors(runner)); } TfLiteSignatureRunner *LLMPipeline::GetPrefillRunner(TfLiteInterpreter* interpreter, std::size_t num_input_tokens, kv_cache_t& kv_cache) { @@ -335,20 +310,20 @@ TfLiteSignatureRunner *LLMPipeline::GetPrefillRunner(TfLiteInterpreter* interpre // The expected shape for input position is [Seq]. size_t seq_size = input_pos->dims->data[0]; if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) { - runner = TfLiteInterpreterGetSignatureRunner(interpreter, key->c_str()); + runner = TfLiteInterpreterGetSignatureRunner(interpreter, key.c_str()); //best_seq_size = seq_size; delta = seq_size - num_input_tokens; } } - MINIMAL_CHECK(runner != nullptr); - PrepareRunner(runner->impl, kv_cache); + MINIMAL_CHECK_PTR(runner != nullptr); + PrepareRunner(runner, kv_cache); return runner; } TfLiteSignatureRunner *LLMPipeline::GetDecodeRunner(TfLiteInterpreter* interpreter, kv_cache_t& kv_cache) { TfLiteSignatureRunner* runner = TfLiteInterpreterGetSignatureRunner(interpreter, "decode"); - MINIMAL_CHECK(runner != nullptr); - PrepareRunner(runner->impl, kv_cache); + MINIMAL_CHECK_PTR(runner != nullptr); + PrepareRunner(runner, kv_cache); return runner; } @@ -357,7 +332,7 @@ sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor(s std::string serialized_proto = std::string( std::istreambuf_iterator(input), std::istreambuf_iterator()); auto processor = new sentencepiece::SentencePieceProcessor(); - MINIMAL_CHECK(processor->LoadFromSerializedProto(serialized_proto).ok()); + MINIMAL_CHECK_PTR(processor->LoadFromSerializedProto(serialized_proto).ok()); return processor; } @@ -375,6 +350,43 @@ int LLMPipeline::GreedySampler(const TfLiteTensor* logits) { return max_index; } +TfLiteRegistration* LLMPipeline::GetGenAIGenerateOp() { + static tflite::MutableOpResolver resolver; + + tflite::ops::custom::GenAIOpsRegisterer(&resolver); + + const TfLiteRegistration* reg = resolver.FindOp("GEN_AI_GENERATE", /*version=*/1); + + if (!reg) { + LOG(ERROR) << "Could not find GEN_AI_GENERATE op." << std::endl; + return nullptr; + } + + static TfLiteRegistration reg_copy = *reg; + return ®_copy; +} + +bool LLMPipeline::TfLiteSignatureRunnerSetInputCustomAllocation(TfLiteSignatureRunner* runner, const char* input_name, const TfLiteCustomAllocation* allocation) { + if (!runner || !input_name || !allocation) return false; + auto* cpp_runner = reinterpret_cast(runner); + return cpp_runner + ->SetCustomAllocationForInputTensor(input_name, *allocation) == + kTfLiteOk; +} + +bool LLMPipeline::TfLiteSignatureRunnerSetOutputCustomAllocation(TfLiteSignatureRunner* runner, const char* output_name, const TfLiteCustomAllocation* allocation) { + if (!runner || !output_name || !allocation) return false; + auto* cpp_runner = reinterpret_cast(runner); + return cpp_runner + ->SetCustomAllocationForOutputTensor(output_name, *allocation) == + kTfLiteOk; +} + +bool LLMPipeline::TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner* runner) { + if (!runner) return false; + return reinterpret_cast(runner)->AllocateTensors() == kTfLiteOk; +} + #ifdef __cplusplus }; #endif // __cplusplus diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index cc24e9808..315e9ad6d 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -37,6 +37,16 @@ if (!(x)) { \ LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ return MLPERF_FAILURE; \ } +#define MINIMAL_CHECK_PTR(x) \ +if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return nullptr; \ +} +#define MINIMAL_CHECK_VOID(x) \ +if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return; \ +} // TF Lite requires all buffers (including external buffers used for KV cache // here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use @@ -149,11 +159,16 @@ class LLMPipeline : public Pipeline { private: TfLiteInterpreter *BuildInterpreter(TfLiteModel *model, int num_threads); kv_cache_t BuildKVCache(TfLiteInterpreter *interpreter); - void PrepareRunner(tflite::SignatureRunner *runner, kv_cache_t &kv_cache); + void PrepareRunner(TfLiteSignatureRunner *runner, kv_cache_t &kv_cache); TfLiteSignatureRunner *GetPrefillRunner(TfLiteInterpreter *interpreter, std::size_t num_input_tokens, kv_cache_t &kv_cache); TfLiteSignatureRunner *GetDecodeRunner(TfLiteInterpreter *interpreter, kv_cache_t &kv_cache); sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor(std::string path); int GreedySampler(const TfLiteTensor *logits); + TfLiteRegistration* GetGenAIGenerateOp(); + bool TfLiteSignatureRunnerSetInputCustomAllocation(struct TfLiteSignatureRunner* runner, const char* input_name, const struct TfLiteCustomAllocation* allocation); + bool TfLiteSignatureRunnerSetOutputCustomAllocation(struct TfLiteSignatureRunner* runner, const char* output_name, const struct TfLiteCustomAllocation* allocation); + bool TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner* runner); + }; #endif // TFLITE_SINGLE_MODEL_PIPELINE_H_ From f598e57dae10c7e46f1563bdeabee0930ab9af99 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 19 Aug 2025 00:19:14 +0300 Subject: [PATCH 03/74] upgrade TensorFlow to 2.18.0 --- .bazelrc | 3 +++ WORKSPACE | 12 +++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index a2ee67b8c..ccbb17890 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,6 +10,9 @@ build --spawn_strategy=standalone # This flag is required by tensorflow common --experimental_repo_remote_exec +common --repo_env=TF_NEED_CUDA=0 +common --repo_env=TF_NEED_ROCM=0 + # Default options should come above this line. # Configure logs diff --git a/WORKSPACE b/WORKSPACE index cc6b12164..7ab14ceef 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -40,13 +40,19 @@ http_archive( "//:flutter/third_party/tensorflow-fix-file-opening-mode-for-Windows.patch", #"//:flutter/third_party/tf-eigen.patch", ] + PATCH_FILE, - sha256 = "9cc4d5773b8ee910079baaecb4086d0c28939f024dd74b33fc5e64779b6533dc", - strip_prefix = "tensorflow-2.17.0", + sha256 = "d7876f4bb0235cac60eb6316392a7c48676729860da1ab659fb440379ad5186d", + strip_prefix = "tensorflow-2.18.0", urls = [ - "https://github.com/tensorflow/tensorflow/archive/v2.17.0.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/v2.18.0.tar.gz", ], ) +load("@org_tensorflow//third_party/gpus:cuda_configure.bzl", "cuda_configure") +cuda_configure(name = "local_config_cuda") + +load("@org_tensorflow//third_party/gpus:rocm_configure.bzl", "rocm_configure") +rocm_configure(name = "local_config_rocm") + http_archive( name = "com_google_sentencepiece", strip_prefix = "sentencepiece-0.1.96", From fe32950d33e26348e28097d3fe928e690a76c6d3 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 19 Aug 2025 00:29:33 +0300 Subject: [PATCH 04/74] upgraded llm pipeline to use TFLite C++ api + small bug fixes --- flutter/cpp/binary/main.cc | 2 +- flutter/cpp/datasets/mmlu_gen.cc | 4 +- .../datasets/mmlu_utils/generate_tfrecords.py | 10 +- .../tflite_settings_android.pbtxt | 2 +- .../cpp/backend_tflite/embedding_utils.h | 3 +- .../cpp/backend_tflite/llm_pipeline.cc | 116 ++++++------------ .../cpp/backend_tflite/llm_pipeline.h | 40 +++--- .../cpp/backend_tflite/tflite_c.cc | 11 +- 8 files changed, 79 insertions(+), 109 deletions(-) diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index aadccf855..df854627b 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -118,7 +118,7 @@ int Main(int argc, char *argv[]) { "Benchmark ID. One of image_classification, " "image_classification_v2, object_detection, " "natural_language_processing, " - "image_segmentation_v2, super_resolution, stable_diffusion, LLM, " + "image_segmentation_v2, super_resolution, stable_diffusion, llm, " "image_classification_offline, image_classification_offline_v2", Flag::kPositional)}; Flags::Parse(&argc, const_cast(argv), flag_list); diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 65a39717f..22933bd14 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -15,8 +15,8 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) tensorflow::tstring record = sample_reader_.ReadRecord(i); tensorflow::Example example; example.ParseFromString(record); - std::string input = GetFeatureValues("input", example).Get(0); - std::string answer = GetFeatureValues("answer", example).Get(0); + std::string input = tensorflow::GetFeatureValues("input", example).Get(0); + std::string answer = tensorflow::GetFeatureValues("answer", example).Get(0); auto sample = std::make_unique(); sample->input = input; diff --git a/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py index a87a389da..793a513d0 100644 --- a/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py +++ b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py @@ -4,8 +4,8 @@ def parse_args(): parser = argparse.ArgumentParser(description="Convert a CSV of LLM prompts to TFRecord format.") - parser.add_argument('--input', type=str, required=True, help="Path to the input CSV file.") - parser.add_argument('--output', type=str, required=True, help="Path to the output TFRecord file.") + parser.add_argument('--input_file', type=str, required=True, help="Path to the input CSV file.") + parser.add_argument('--output_file', type=str, required=True, help="Path to the output TFRecord file.") return parser.parse_args() def map_answer(num): @@ -19,19 +19,19 @@ def create_example(input_text, answer_letter): def main(): args = parse_args() - df = pd.read_csv(args.input_csv) + df = pd.read_csv(args.input_file) if "input_formatted" not in df.columns or "answer" not in df.columns: raise ValueError("CSV must contain 'input_formatted' and 'answer' columns.") df["answer_letter"] = df["answer"].map(map_answer) - with tf.io.TFRecordWriter(args.output_tfrecord) as writer: + with tf.io.TFRecordWriter(args.output_file) as writer: for _, row in df.iterrows(): example = create_example(row["input_formatted"], row["answer_letter"]) writer.write(example.SerializeToString()) - print(f"TFRecord written to: {args.output_tfrecord}") + print(f"TFRecord written to: {args.output_file}") if __name__ == "__main__": main() diff --git a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt index 11eb346bf..5df21de7d 100644 --- a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt +++ b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt @@ -277,6 +277,6 @@ benchmark_setting { } custom_setting { id: "sentencepiece_processor_path" - value: "llama.model.sp" + value: "llama3_1b.spm.model" } } diff --git a/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h b/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h index f543c6332..74951ed1b 100644 --- a/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h +++ b/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h @@ -6,6 +6,7 @@ #include #include #include +#include class TsEmbeddingParser { public: @@ -37,4 +38,4 @@ class EmbeddingManager { std::unique_ptr ts_parser_; }; -#endif // EMBEDDING_UTILS_H_ \ No newline at end of file +#endif // EMBEDDING_UTILS_H_ diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 06f739206..e312f28a7 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -44,14 +44,7 @@ static bool backendExists = false; // Destroy the backend pointer and its data. void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - if (backend_data) { - TfLiteModelDelete(backend_data->model); - TfLiteSignatureRunnerDelete(backend_data->prefill_runner); - TfLiteSignatureRunnerDelete(backend_data->decode_runner); - TfLiteInterpreterDelete(backend_data->interpreter); - delete backend_data->sp_processor; - delete backend_data; - } + if (backend_data) delete backend_data; backendExists = false; } @@ -70,7 +63,7 @@ mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_ std::string sppp = mlperf::mobile::GetConfigValue(configs, "sentencepiece_processor_path", std::string("")); // Load the model. - backend_data->model = TfLiteModelCreateFromFile(model_path); + backend_data->model = tflite::FlatBufferModel::BuildFromFile(model_path).release(); if (!backend_data->model) { LOG(ERROR) << "Failed to load model: " << model_path; backend_delete(backend_data); @@ -123,15 +116,15 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt // Get Input Tensors for each of the runners. // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* prefill_input = TfLiteSignatureRunnerGetInputTensor(backend_data->prefill_runner, "tokens"); + TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); // Shape: [Seq], Dtype: int32 - TfLiteTensor* prefill_input_pos = TfLiteSignatureRunnerGetInputTensor(backend_data->prefill_runner, "input_pos"); + TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* decode_input = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "tokens"); + TfLiteTensor* decode_input = backend_data->decode_runner->input_tensor("tokens"); // Shape: [Seq], Dtype: int32 - TfLiteTensor* decode_input_pos = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "input_pos"); + TfLiteTensor* decode_input_pos = backend_data->decode_runner->input_tensor("input_pos"); // shape: [Batch, kv_cache_max, num_query_groups, head_dim] - TfLiteTensor* kv_cache_k_0 = TfLiteSignatureRunnerGetInputTensor(backend_data->decode_runner, "kv_cache_k_0"); + TfLiteTensor* kv_cache_k_0 = backend_data->decode_runner->input_tensor("kv_cache_k_0"); int max_seq_size = prefill_input->dims->data[1]; int kv_cache_max_size = kv_cache_k_0->dims->data[1]; @@ -144,7 +137,7 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt prefill_input_pos->data.i32[i] = i; } - MINIMAL_CHECK(TfLiteSignatureRunnerInvoke(backend_data->prefill_runner) == kTfLiteOk); + MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); int decode_steps = kv_cache_max_size - prefill_seq_size; MINIMAL_CHECK(decode_steps > 0); @@ -156,8 +149,8 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt for (int i = 0; i < decode_steps; ++i) { decode_input->data.i32[0] = next_token; decode_input_pos->data.i32[0] = next_position; - MINIMAL_CHECK(TfLiteSignatureRunnerInvoke(backend_data->decode_runner) == kTfLiteOk); - next_token = GreedySampler(TfLiteSignatureRunnerGetOutputTensor(backend_data->decode_runner, "logits")); + MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); + next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); output_tokens.push_back(next_token); next_position += 1; if (next_token == backend_data->stop_token_id) break; @@ -242,28 +235,29 @@ void LLMPipeline::backend_release_buffer(void *p) { ::operator delete(p); } -TfLiteInterpreter *LLMPipeline::BuildInterpreter(TfLiteModel *model, int num_threads) { - TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); - TfLiteInterpreterOptionsSetNumThreads(options, num_threads); +tflite::Interpreter *LLMPipeline::BuildInterpreter(tflite::FlatBufferModel *model, int num_threads) { + tflite::ops::builtin::BuiltinOpResolver resolver; // NOTE: We need to manually register optimized OPs for KV-cache and // Scaled Dot Product Attention (SDPA). - TfLiteInterpreterOptionsAddCustomOp(options, "GEN_AI_GENERATE", GetGenAIGenerateOp(), 1, 1); - TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + tflite::ops::custom::GenAIOpsRegisterer(&resolver); + tflite::InterpreterBuilder builder(*model, resolver); + MINIMAL_CHECK_PTR(builder.SetNumThreads(num_threads) == kTfLiteOk); + std::unique_ptr interpreter; + builder(&interpreter); MINIMAL_CHECK_PTR(interpreter != nullptr); - return interpreter; + return interpreter.release(); } -kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { - TfLiteSignatureRunner *runner = TfLiteInterpreterGetSignatureRunner(interpreter, "decode"); - // TODO +kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { + tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); if (runner == nullptr) { return {}; } // The two arguments excluded are `tokens` and `input_pos`. // TODO more arguments might need to be excluded - size_t num_layers = (TfLiteSignatureRunnerGetInputCount(runner) - 2) / 2; + size_t num_layers = (runner->input_size() - 2) / 2; if (num_layers == 0) { return {}; } @@ -273,7 +267,7 @@ kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { std::string k_cache_name = "kv_cache_k_" + std::to_string(i); std::string v_cache_name = "kv_cache_v_" + std::to_string(i); // We are assuming K and V tensors are of the same shape. - TfLiteTensor* tensor = TfLiteSignatureRunnerGetInputTensor(runner, k_cache_name.c_str()); + TfLiteTensor* tensor = runner->input_tensor(k_cache_name.c_str()); size_t count = tensor->bytes / sizeof(float); kv_cache.emplace(k_cache_name, std::vector>(count, 0.0)); @@ -284,33 +278,30 @@ kv_cache_t LLMPipeline::BuildKVCache(TfLiteInterpreter *interpreter) { return kv_cache; } -void LLMPipeline::PrepareRunner(TfLiteSignatureRunner* runner, kv_cache_t& kv_cache) { +void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, kv_cache_t& kv_cache) { for (auto& [name, cache] : kv_cache) { - TfLiteCustomAllocation allocation = { - .data = static_cast(cache.data()), - .bytes = cache.size() * sizeof(float)}; + TfLiteCustomAllocation allocation = {.data = static_cast(cache.data()), .bytes = cache.size() * sizeof(float)}; // Both input and output tensors are set to the same buffer. Not all // delegates support this in-place update. For those cases, we need to do // a ping-pong buffer and update the pointers between inference calls. - MINIMAL_CHECK_VOID(TfLiteSignatureRunnerSetInputCustomAllocation(runner, name.c_str(), &allocation)); - MINIMAL_CHECK_VOID(TfLiteSignatureRunnerSetInputCustomAllocation(runner, name.c_str(), &allocation)); + MINIMAL_CHECK_VOID(runner->SetCustomAllocationForInputTensor(name.c_str(), allocation) == kTfLiteOk); + MINIMAL_CHECK_VOID(runner->SetCustomAllocationForOutputTensor(name.c_str(), allocation) == kTfLiteOk); } - MINIMAL_CHECK_VOID(TfLiteSignatureRunnerAllocateTensors(runner)); + MINIMAL_CHECK_VOID(runner->AllocateTensors() == kTfLiteOk); } -TfLiteSignatureRunner *LLMPipeline::GetPrefillRunner(TfLiteInterpreter* interpreter, std::size_t num_input_tokens, kv_cache_t& kv_cache) { +tflite::SignatureRunner *LLMPipeline::GetPrefillRunner(tflite::Interpreter* interpreter, std::size_t num_input_tokens, kv_cache_t& kv_cache) { // Find the prefill signature length that best matches the input token size. - TfLiteSignatureRunner* runner = nullptr; + tflite::SignatureRunner* runner = nullptr; //int best_seq_size = -1; size_t delta = std::numeric_limits::max(); - for (int32_t i = 0; i < TfLiteInterpreterGetSignatureCount(interpreter); i++) { - std::string key (TfLiteInterpreterGetSignatureKey(interpreter, i)); - if (key.find("prefill") == std::string::npos) continue; - TfLiteTensor* input_pos = TfLiteSignatureRunnerGetInputTensor(TfLiteInterpreterGetSignatureRunner(interpreter, key.c_str()), "input_pos"); + for (const std::string* key : interpreter->signature_keys()) { + if (key->find("prefill") == std::string::npos) continue; + TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())->input_tensor("input_pos"); // The expected shape for input position is [Seq]. size_t seq_size = input_pos->dims->data[0]; if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) { - runner = TfLiteInterpreterGetSignatureRunner(interpreter, key.c_str()); + runner = interpreter->GetSignatureRunner(key->c_str()); //best_seq_size = seq_size; delta = seq_size - num_input_tokens; } @@ -320,8 +311,8 @@ TfLiteSignatureRunner *LLMPipeline::GetPrefillRunner(TfLiteInterpreter* interpre return runner; } -TfLiteSignatureRunner *LLMPipeline::GetDecodeRunner(TfLiteInterpreter* interpreter, kv_cache_t& kv_cache) { - TfLiteSignatureRunner* runner = TfLiteInterpreterGetSignatureRunner(interpreter, "decode"); +tflite::SignatureRunner *LLMPipeline::GetDecodeRunner(tflite::Interpreter* interpreter, kv_cache_t& kv_cache) { + tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); MINIMAL_CHECK_PTR(runner != nullptr); PrepareRunner(runner, kv_cache); return runner; @@ -350,43 +341,6 @@ int LLMPipeline::GreedySampler(const TfLiteTensor* logits) { return max_index; } -TfLiteRegistration* LLMPipeline::GetGenAIGenerateOp() { - static tflite::MutableOpResolver resolver; - - tflite::ops::custom::GenAIOpsRegisterer(&resolver); - - const TfLiteRegistration* reg = resolver.FindOp("GEN_AI_GENERATE", /*version=*/1); - - if (!reg) { - LOG(ERROR) << "Could not find GEN_AI_GENERATE op." << std::endl; - return nullptr; - } - - static TfLiteRegistration reg_copy = *reg; - return ®_copy; -} - -bool LLMPipeline::TfLiteSignatureRunnerSetInputCustomAllocation(TfLiteSignatureRunner* runner, const char* input_name, const TfLiteCustomAllocation* allocation) { - if (!runner || !input_name || !allocation) return false; - auto* cpp_runner = reinterpret_cast(runner); - return cpp_runner - ->SetCustomAllocationForInputTensor(input_name, *allocation) == - kTfLiteOk; -} - -bool LLMPipeline::TfLiteSignatureRunnerSetOutputCustomAllocation(TfLiteSignatureRunner* runner, const char* output_name, const TfLiteCustomAllocation* allocation) { - if (!runner || !output_name || !allocation) return false; - auto* cpp_runner = reinterpret_cast(runner); - return cpp_runner - ->SetCustomAllocationForOutputTensor(output_name, *allocation) == - kTfLiteOk; -} - -bool LLMPipeline::TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner* runner) { - if (!runner) return false; - return reinterpret_cast(runner)->AllocateTensors() == kTfLiteOk; -} - #ifdef __cplusplus }; #endif // __cplusplus diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 315e9ad6d..ad29c6b33 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -19,16 +19,14 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "pipeline.h" -#include "tensorflow/lite/c/c_api.h" -#include "tensorflow/lite/c/c_api_experimental.h" + +#include "src/sentencepiece_processor.h" #include "tensorflow/lite/experimental/genai/genai_ops.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/signature_runner.h" -#include "tensorflow/lite/util.h" -#include "src/sentencepiece_processor.h" #include "tensorflow/core/platform/logging.h" @@ -79,12 +77,12 @@ struct LLMBackendData { const char *name = "TFLite"; const char *vendor = "Google"; const char *accelerator = "CPU"; - TfLiteModel *model{nullptr}; + tflite::FlatBufferModel *model{nullptr}; sentencepiece::SentencePieceProcessor *sp_processor{nullptr}; //TfLiteInterpreterOptions *options{}; TODO use this to allow different delegates other than CPU? - TfLiteInterpreter *interpreter{}; - TfLiteSignatureRunner *prefill_runner{nullptr}; - TfLiteSignatureRunner *decode_runner{nullptr}; + tflite::Interpreter *interpreter{}; + tflite::SignatureRunner *prefill_runner{nullptr}; + tflite::SignatureRunner *decode_runner{nullptr}; kv_cache_t kv_cache; //std::string input_prompt; std::vector prompt_tokens; @@ -94,6 +92,18 @@ struct LLMBackendData { int stop_token_id = -1; std::string output; + LLMBackendData(){} + + ~LLMBackendData() { + // Runners are owned by interpreter and therefore don't need to be deleted + delete sp_processor; + delete interpreter; + delete model; + } + + LLMBackendData(const LLMBackendData&) = delete; + LLMBackendData& operator=(const LLMBackendData&) = delete; + // uint32_t real_batch_size = 1; //std::unique_ptr executer; // int32_t original_tensor_size = 0; @@ -157,17 +167,13 @@ class LLMPipeline : public Pipeline { void backend_release_buffer(void *p) override; private: - TfLiteInterpreter *BuildInterpreter(TfLiteModel *model, int num_threads); - kv_cache_t BuildKVCache(TfLiteInterpreter *interpreter); - void PrepareRunner(TfLiteSignatureRunner *runner, kv_cache_t &kv_cache); - TfLiteSignatureRunner *GetPrefillRunner(TfLiteInterpreter *interpreter, std::size_t num_input_tokens, kv_cache_t &kv_cache); - TfLiteSignatureRunner *GetDecodeRunner(TfLiteInterpreter *interpreter, kv_cache_t &kv_cache); + tflite::Interpreter *BuildInterpreter(tflite::FlatBufferModel *model, int num_threads); + kv_cache_t BuildKVCache(tflite::Interpreter *interpreter); + void PrepareRunner(tflite::SignatureRunner *runner, kv_cache_t &kv_cache); + tflite::SignatureRunner *GetPrefillRunner(tflite::Interpreter *interpreter, std::size_t num_input_tokens, kv_cache_t &kv_cache); + tflite::SignatureRunner *GetDecodeRunner(tflite::Interpreter *interpreter, kv_cache_t &kv_cache); sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor(std::string path); int GreedySampler(const TfLiteTensor *logits); - TfLiteRegistration* GetGenAIGenerateOp(); - bool TfLiteSignatureRunnerSetInputCustomAllocation(struct TfLiteSignatureRunner* runner, const char* input_name, const struct TfLiteCustomAllocation* allocation); - bool TfLiteSignatureRunnerSetOutputCustomAllocation(struct TfLiteSignatureRunner* runner, const char* output_name, const struct TfLiteCustomAllocation* allocation); - bool TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner* runner); }; diff --git a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc index 62a6a18bc..64ce37147 100644 --- a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc +++ b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc @@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include "single_model_pipeline.h" #include "stable_diffusion_pipeline.h" +#include "llm_pipeline.h" #include "tensorflow/core/platform/logging.h" #include "tflite_settings_android.h" #include "tflite_settings_apple.h" @@ -37,11 +39,17 @@ extern "C" { std::unique_ptr pipeline; void init_pipeline(const char *pipeline_type) { + //TODO use a switch/case bool sd_pipeline = (strcmp(pipeline_type, "StableDiffusionPipeline") == 0); + bool llm_pipeline = (strcmp(pipeline_type, "LLMPipeline") == 0); if (sd_pipeline) { LOG(INFO) << "Initializing StableDiffusionPipeline"; pipeline = std::make_unique(); - } else { + } else if (llm_pipeline) { + LOG(INFO) << "Initializing LLMPipeline"; + pipeline = std::make_unique(); + } + else { LOG(INFO) << "Initializing SingleModelPipeline"; pipeline = std::make_unique(); } @@ -145,6 +153,7 @@ bool mlperf_backend_matches_hardware(const char **not_allowed_message, mlperf_backend_ptr_t mlperf_backend_create( const char *model_path, mlperf_backend_configuration_t *configs, const char *native_lib_path) { + LOG(INFO) << "Using TfLite " << TfLiteVersion() << " With Schema " << TfLiteSchemaVersion() << std::endl; const char *pipeline_type = ""; for (int i = 0; i < configs->count; ++i) { if (strcmp(configs->keys[i], "pipeline") == 0) { From 24ad1d55d351b5f4166e9513751db7c432c90b7c Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 19 Aug 2025 00:34:35 +0300 Subject: [PATCH 05/74] basic flutter app support for icon and dataset --- flutter/assets/icons/ic_task_llm.svg | 113 +++++++++++----------- flutter/assets/tasks.pbtxt | 28 +++++- flutter/cpp/flutter/dart_run_benchmark.cc | 5 + flutter/lib/app_constants.dart | 1 + flutter/lib/ui/icons.dart | 4 + 5 files changed, 93 insertions(+), 58 deletions(-) diff --git a/flutter/assets/icons/ic_task_llm.svg b/flutter/assets/icons/ic_task_llm.svg index ffede9c72..b750915f7 100644 --- a/flutter/assets/icons/ic_task_llm.svg +++ b/flutter/assets/icons/ic_task_llm.svg @@ -22,78 +22,77 @@ inkscape:pagecheckerboard="1" inkscape:deskcolor="#d1d1d1" inkscape:zoom="0.58212013" - inkscape:cx="409.70924" + inkscape:cx="408.85032" inkscape:cy="540.26649" - inkscape:window-width="1920" - inkscape:window-height="1052" + inkscape:window-width="1025" + inkscape:window-height="1080" inkscape:window-x="0" inkscape:window-y="0" - inkscape:window-maximized="1" + inkscape:window-maximized="0" inkscape:current-layer="svg3" /> - - - - - - - - - - + + id="path3" + inkscape:path-effect="#path-effect7" + transform="matrix(0.88221877,0.40039021,-0.40039021,0.88221877,315.93157,-156.70199)" /> + id="defs3"> + + + diff --git a/flutter/assets/tasks.pbtxt b/flutter/assets/tasks.pbtxt index 8cb92d03e..54c59a859 100644 --- a/flutter/assets/tasks.pbtxt +++ b/flutter/assets/tasks.pbtxt @@ -348,13 +348,39 @@ task { min_duration: 60 max_duration: 300 } + quick { + min_query_count: 128 + min_duration: 10 + max_duration: 40 + } + rapid { + min_query_count: 64 + min_duration: 6 + max_duration: 60 + } } datasets { type: MMLU + full { + name: "TinyMMLU prompt set for LLM" + input_path: "https://thee.dev/mlc/data.tfrecord" + input_checksum: "b564d2c228a867148fa7d6df415a0368" + groundtruth_path: "" + groundtruth_checksum: "" + } + lite { + name: "TinyMMLU prompt set for LLM" + input_path: "https://thee.dev/mlc/data.tfrecord" + input_checksum: "b564d2c228a867148fa7d6df415a0368" + groundtruth_path: "" + groundtruth_checksum: "" + } tiny { name: "TinyMMLU prompt set for LLM" - input_path: "https://thee.dev/mlc/data.tfrecord" #TODO placeholder + input_path: "https://thee.dev/mlc/data.tfrecord" input_checksum: "b564d2c228a867148fa7d6df415a0368" + groundtruth_path: "" + groundtruth_checksum: "" } } model { diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index 63aa4dfb8..9030496c9 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -14,6 +14,7 @@ #include "flutter/cpp/datasets/imagenet.h" #include "flutter/cpp/datasets/snu_sr.h" #include "flutter/cpp/datasets/squad.h" +#include "flutter/cpp/datasets/mmlu_gen.h" #include "flutter/cpp/mlperf_driver.h" #include "flutter/cpp/proto/backend_setting.pb.h" #include "flutter/cpp/proto/mlperf_task.pb.h" @@ -105,6 +106,10 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( backend.get(), in->dataset_data_path, in->dataset_groundtruth_path, in->output_dir); break; + case ::mlperf::mobile::DatasetConfig::MMLU: + dataset = std::make_unique<::mlperf::mobile::MmluGen>( + backend.get(), in->dataset_data_path); + break; default: return nullptr; } diff --git a/flutter/lib/app_constants.dart b/flutter/lib/app_constants.dart index b0849c36f..56ddc681e 100644 --- a/flutter/lib/app_constants.dart +++ b/flutter/lib/app_constants.dart @@ -25,6 +25,7 @@ class BenchmarkId { static const imageClassificationV2 = 'image_classification_v2'; static const imageClassificationOfflineV2 = 'image_classification_offline_v2'; static const stableDiffusion = 'stable_diffusion'; + static const llm = 'llm'; // The sort order of this list will be used in the UI static const allIds = [ diff --git a/flutter/lib/ui/icons.dart b/flutter/lib/ui/icons.dart index 524813430..e47263e61 100644 --- a/flutter/lib/ui/icons.dart +++ b/flutter/lib/ui/icons.dart @@ -28,6 +28,8 @@ class AppIcons { _pSvg('ic_task_super_resolution.svg'); static final SvgPicture stableDiffusion = _pSvg('ic_task_stable_diffusion.svg'); + static final SvgPicture llm = + _pSvg('ic_task_llm.svg'); static final SvgPicture imageClassificationWhite = _pSvg('ic_task_image_classification_white.svg'); @@ -70,6 +72,7 @@ class BenchmarkIcons { BenchmarkId.stableDiffusion: AppIcons.stableDiffusion, BenchmarkId.imageClassificationOfflineV2: AppIcons.imageClassificationOffline, + BenchmarkId.llm: AppIcons.llm, }; static final lightSet = { @@ -81,6 +84,7 @@ class BenchmarkIcons { BenchmarkId.stableDiffusion: AppIcons.stableDiffusionWhite, BenchmarkId.imageClassificationOfflineV2: AppIcons.imageClassificationOfflineWhite, + BenchmarkId.llm: AppIcons.llm, }; static Widget getDarkIcon(String benchmarkId) => From aa094399521235067a4ef06a3516f6aab9267d78 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 19 Aug 2025 09:04:41 +0300 Subject: [PATCH 06/74] added linux x86_64 config for internal testing --- .bazelrc | 24 ++++++++++++++++++++++++ flutter/cpp/binary/cmdline-docker.mk | 8 +++++++- flutter/cpp/binary/cmdline.mk | 26 +++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index ccbb17890..957da8fbd 100644 --- a/.bazelrc +++ b/.bazelrc @@ -10,6 +10,7 @@ build --spawn_strategy=standalone # This flag is required by tensorflow common --experimental_repo_remote_exec +# Without these, tensorflow complains about lack of CUDA library. common --repo_env=TF_NEED_CUDA=0 common --repo_env=TF_NEED_ROCM=0 @@ -21,6 +22,7 @@ build:verbose_logs --output_filter= # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. build:android --copt=-w +build:linux --copt=-w build:ios --copt=-w build:windows --copt=/W0 @@ -28,6 +30,8 @@ build:windows --copt=/W0 build --cxxopt=-std=c++17 build:android --cxxopt=-std=c++17 build:android --host_cxxopt=-std=c++17 +build:linux --cxxopt=-std=c++17 +build:linux --host_cxxopt=-std=c++17 build:ios --cxxopt=-std=c++17 build:ios --host_cxxopt=-std=c++17 build:ios --cxxopt=-xobjective-c++ @@ -44,6 +48,26 @@ build:android_x86_64 --config=android build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 +# Linux configs +build:linux_x86_64 --config=linux +build:linux_x86_64 --cpu=k8 + +# These are neccessary because the compiler that bazel 6.3 uses doesn't support VNNI +build:linux_x86_64 --define=xnn_enable_avx=false +build:linux_x86_64 --define=xnn_enable_avx2=false +build:linux_x86_64 --define=xnn_enable_avx512=false +build:linux_x86_64 --define=xnn_enable_avx512fp16=false +build:linux_x86_64 --define=xnn_enable_avxvnni=false +build:linux_x86_64 --define=xnn_enable_avxvnniint8=false +build:linux_x86_64 --define=xnn_enable_vnni=false + +# Optional, enable for debugging or compilation errors +#build:linux_x86_64 --action_env=CC=gcc +#build:linux_x86_64 --action_env=CXX=g++ +#build:linux_x86_64 --strip=never +#build:linux_x86_64 --copt=-fno-omit-frame-pointer +#build:linux_x86_64 --linkopt=-fno-omit-frame-pointer + # iOS configs build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing diff --git a/flutter/cpp/binary/cmdline-docker.mk b/flutter/cpp/binary/cmdline-docker.mk index 63e92f38c..99e103668 100644 --- a/flutter/cpp/binary/cmdline-docker.mk +++ b/flutter/cpp/binary/cmdline-docker.mk @@ -17,4 +17,10 @@ docker/cmdline/android/release: flutter/android/docker/image MSYS2_ARG_CONV_EXCL="*" docker run \ ${flutter_common_docker_flags} \ - make cmdline/android/bins/release \ No newline at end of file + make cmdline/android/bins/release + +.PHONY: docker/cmdline/linux/release +docker/cmdline/linux/release: flutter/android/docker/image + MSYS2_ARG_CONV_EXCL="*" docker run \ + ${flutter_common_docker_flags} \ + make cmdline/linux/bins/release diff --git a/flutter/cpp/binary/cmdline.mk b/flutter/cpp/binary/cmdline.mk index f920044ae..5078a2e38 100644 --- a/flutter/cpp/binary/cmdline.mk +++ b/flutter/cpp/binary/cmdline.mk @@ -16,6 +16,7 @@ include flutter/cpp/binary/cmdline-docker.mk cmdline/android/bins/release: cmdline/android/libs/deps cmdline/android/bins/build cmdline/android/bins/copy +cmdline/linux/bins/release: cmdline/linux/bins/build cmdline/linux/bins/copy .PHONY: cmdline/android/libs/deps cmdline/android/libs/deps: @@ -52,6 +53,29 @@ cmdline/android/bins/copy: @# macos doesn't support --recursive flag chmod -R 777 ${cmdline_android_bin_release_path} +.PHONY: cmdline/linux/bins/build +cmdline/linux/bins/build: + bazel ${BAZEL_OUTPUT_ROOT_ARG} ${proxy_bazel_args} ${sonar_bazel_startup_options} \ + build ${BAZEL_CACHE_ARG} ${bazel_links_arg} ${sonar_bazel_build_args} \ + --config=linux_x86_64 \ + ${backend_tflite_android_target} \ + //flutter/cpp/flutter:libbackendbridge.so \ + //flutter/cpp/binary:main + +cmdline_linux_bin_release_path=output/linux/cmdline +.PHONY: cmdline/linux/bins/copy +cmdline/linux/bins/copy: + rm -rf ${cmdline_linux_bin_release_path} + mkdir -p ${cmdline_linux_bin_release_path} + @# macos doesn't support --target-directory flag + cp -f \ + ${backend_tflite_android_files} \ + ${BAZEL_LINKS_PREFIX}bin/flutter/cpp/flutter/libbackendbridge.so \ + ${BAZEL_LINKS_PREFIX}bin/flutter/cpp/binary/main \ + ${cmdline_linux_bin_release_path} + @# macos doesn't support --recursive flag + chmod -R 777 ${cmdline_linux_bin_release_path} + windows_cmdline_folder=output/windows/cmdline .PHONY: cmdline/windows/bins cmdline/windows/bins: @@ -92,4 +116,4 @@ cmdline/windows/prepare-dlls: .PHONY: cmdline/windows/copy-dlls cmdline/windows/copy-dlls: currentDir=$$(pwd) && cd "${msvc_arm_dlls_path}" && \ - cp --target-directory $$currentDir/${windows_cmdline_folder} ${msvc_arm_dlls_list} \ No newline at end of file + cp --target-directory $$currentDir/${windows_cmdline_folder} ${msvc_arm_dlls_list} From 84b164e26d9284b78622fe271c5e0e7698390f71 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 26 Aug 2025 03:38:30 +0300 Subject: [PATCH 07/74] updated bazel config to use SSE/MMX instructions --- .bazelrc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 957da8fbd..2b807cc72 100644 --- a/.bazelrc +++ b/.bazelrc @@ -51,15 +51,17 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Linux configs build:linux_x86_64 --config=linux build:linux_x86_64 --cpu=k8 +# Not required, but enables the proper SSE/MMX instructions per CPU +build:linux_x86_64 --copt=-march=native # These are neccessary because the compiler that bazel 6.3 uses doesn't support VNNI -build:linux_x86_64 --define=xnn_enable_avx=false -build:linux_x86_64 --define=xnn_enable_avx2=false -build:linux_x86_64 --define=xnn_enable_avx512=false +#build:linux_x86_64 --define=xnn_enable_avx=false +#build:linux_x86_64 --define=xnn_enable_avx2=false +#build:linux_x86_64 --define=xnn_enable_avx512=false build:linux_x86_64 --define=xnn_enable_avx512fp16=false -build:linux_x86_64 --define=xnn_enable_avxvnni=false +#build:linux_x86_64 --define=xnn_enable_avxvnni=false build:linux_x86_64 --define=xnn_enable_avxvnniint8=false -build:linux_x86_64 --define=xnn_enable_vnni=false +#build:linux_x86_64 --define=xnn_enable_vnni=false # Optional, enable for debugging or compilation errors #build:linux_x86_64 --action_env=CC=gcc From d57040cca5420c642516fc7063f6ffa459bf06da Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 26 Aug 2025 03:42:01 +0300 Subject: [PATCH 08/74] fixed incorrect answer format and compression --- .../datasets/mmlu_utils/generate_tfrecords.py | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py index 793a513d0..2e7bfcf0d 100644 --- a/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py +++ b/flutter/cpp/datasets/mmlu_utils/generate_tfrecords.py @@ -3,30 +3,46 @@ import argparse def parse_args(): - parser = argparse.ArgumentParser(description="Convert a CSV of LLM prompts to TFRecord format.") - parser.add_argument('--input_file', type=str, required=True, help="Path to the input CSV file.") + parser = argparse.ArgumentParser(description="Convert a Parquet of LLM prompts to TFRecord format.") + parser.add_argument('--input_file', type=str, required=True, help="Path to the input Parquet (.parquet) file.") parser.add_argument('--output_file', type=str, required=True, help="Path to the output TFRecord file.") return parser.parse_args() def map_answer(num): - return {1: "A", 2: "B", 3: "C", 4: "D"}.get(num, "X") # Use 'X' as fallback + return {0: "A", 1: "B", 2: "C", 3: "D"}.get(num, "X") # Use 'X' as fallback def create_example(input_text, answer_letter): return tf.train.Example(features=tf.train.Features(feature={ - "input": tf.train.Feature(bytes_list=tf.train.BytesList(value=[input_text.encode()])), + "input": tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(input_text).encode()])), "answer": tf.train.Feature(bytes_list=tf.train.BytesList(value=[answer_letter.encode()])), })) def main(): args = parse_args() - df = pd.read_csv(args.input_file) + + # Read Parquet (requires 'pyarrow' or 'fastparquet') + try: + df = pd.read_parquet(args.input_file) + except ImportError as e: + raise ImportError( + "Reading Parquet requires 'pyarrow' or 'fastparquet'. " + "Install one, e.g. `pip install pyarrow`." + ) from e if "input_formatted" not in df.columns or "answer" not in df.columns: - raise ValueError("CSV must contain 'input_formatted' and 'answer' columns.") + raise ValueError("Parquet must contain 'input_formatted' and 'answer' columns.") + + # Robustly map numeric answers to letters + def to_letter(x): + try: + return map_answer(int(x)) + except (ValueError, TypeError): + return "X" - df["answer_letter"] = df["answer"].map(map_answer) + df["answer_letter"] = df["answer"].apply(to_letter) - with tf.io.TFRecordWriter(args.output_file) as writer: + options = tf.io.TFRecordOptions(compression_type="ZLIB") + with tf.io.TFRecordWriter(args.output_file, options=options) as writer: for _, row in df.iterrows(): example = create_example(row["input_formatted"], row["answer_letter"]) writer.write(example.SerializeToString()) From f9e40a5bb3ee34dc372c6e2cbf7a191ad58c4c8c Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 26 Aug 2025 03:51:03 +0300 Subject: [PATCH 09/74] got pipeline and dataset to produce proper results + fixed issues where pipeline cannot handle an input size larger than the max prefill size --- flutter/cpp/datasets/mmlu_gen.cc | 24 ++++++++--- flutter/cpp/datasets/mmlu_gen.h | 7 +++- .../cpp/backend_tflite/llm_pipeline.cc | 42 +++++++++++++++---- .../cpp/backend_tflite/llm_pipeline.h | 3 +- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 22933bd14..742237da8 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -8,8 +8,9 @@ namespace mlperf { namespace mobile { -MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord/*, const std::string& input_sppp*/) : sample_reader_(input_tfrecord), Dataset(backend) { + std::cout << "MMLUT-DATASET: " << "Initializing with TFRecord " << input_tfrecord << " with sample size " << std::to_string(sample_reader_.Size()) << std::endl; // Load all TFRecord samples into memory for (size_t i = 0; i < sample_reader_.Size(); i++) { tensorflow::tstring record = sample_reader_.ReadRecord(i); @@ -20,13 +21,17 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) auto sample = std::make_unique(); sample->input = input; - sample->correct_answer = answer; + sample->answer = answer; + + std::cout << "MMLUT-DATASET: " << "Loading TFRecord Data index " << std::to_string(i) << " with answer {" << answer << "}" << std::endl; samples_.push_back(std::move(sample)); } + //LoadSentencePieceProcessor(input_sppp); } void MmluGen::LoadSamplesToRam(const std::vector& samples) { + std::cout << "MMLUT-DATASET: " << "Loading Samples..." << std::endl; for (auto id : samples) { loaded_sample_ids_.insert(id); } @@ -39,6 +44,7 @@ void MmluGen::UnloadSamplesFromRam(const std::vector& samples) } std::vector MmluGen::GetData(int sample_idx) { + std::cout << "MMLUT-DATASET: " << "Getting data at index " << std::to_string(sample_idx) << " (Answer is " << samples_[sample_idx]->answer << ")" << std::endl; std::vector data; if (sample_idx < samples_.size()) { data.push_back(reinterpret_cast(const_cast(samples_[sample_idx]->input.c_str()))); @@ -50,14 +56,16 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vec if (sample_idx >= samples_.size() || outputs.empty()) return {0}; const char* prediction = reinterpret_cast(outputs[0]); - char predicted_char = prediction[0]; // Assume first token is the answer (e.g., 'A', 'B', ...) - - const std::string& correct = samples_[sample_idx]->correct_answer; + char predicted_char = prediction[1]; // Assume second token is the answer because of whitespace (e.g., 'A', 'B', ...) + std::cout << "MMLUT-DATASET: " << "Predicted answer: " << predicted_char << std::endl; + const std::string& correct = samples_[sample_idx]->answer; bool is_correct = (predicted_char == correct[0]); total_++; if (is_correct) correct_++; + std::cout << "MMLUT-DATASET: " << "Accuracy: " << std::to_string(correct_) << "/" << std::to_string(total_) << std::endl; + return {static_cast(is_correct)}; } @@ -74,5 +82,11 @@ std::string MmluGen::ComputeAccuracyString() { return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; } +//void MmluGen::loadSentencePieceProcessor(std::string path) { +// std::ifstream input(path, std::ios::binary); +// std::string serialized_proto = std::string(std::istreambuf_iterator(input), std::istreambuf_iterator()); +// if(!sp_processor->LoadFromSerializedProto(serialized_proto).ok()) LOG(FATAL) << "Could not load SP Processor"; +//} + } // namespace mobile } // namespace mlperf diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 4844a2be8..9773d2338 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -10,6 +10,7 @@ #include #include +//#include "src/sentencepiece_processor.h" #include "flutter/cpp/dataset.h" #include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" @@ -38,14 +39,18 @@ class MmluGen : public Dataset { std::string ComputeAccuracyString() override; + private: + //void loadSentencePieceProcessor(std::string path); + const std::string name_ = "MmluGen"; TFRecordReader sample_reader_; + //sentencepiece::SentencePieceProcessor sp_processor; struct PromptSample { std::string input; - std::string correct_answer; + std::string answer; }; std::vector> samples_; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index e312f28a7..5bbf31efb 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -98,7 +98,6 @@ const char *LLMPipeline::backend_vendor_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->vendor; } -// TODO: Return the name of the accelerator. const char *LLMPipeline::backend_accelerator_name(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; return backend_data->accelerator; @@ -129,6 +128,7 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt int max_seq_size = prefill_input->dims->data[1]; int kv_cache_max_size = kv_cache_k_0->dims->data[1]; int prefill_seq_size = std::min(static_cast(backend_data->prompt_tokens.size()), max_seq_size); + int decode_input_size = std::max(static_cast(backend_data->prompt_tokens.size()) - max_seq_size, 0); // Making sure the decode seq isn't negative for later subtractions std::memset(prefill_input->data.i32, 0, prefill_input->bytes); std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); @@ -139,25 +139,38 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); - int decode_steps = kv_cache_max_size - prefill_seq_size; + // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle + int decode_steps = std::min(backend_data->max_output_tokens+decode_input_size, kv_cache_max_size - prefill_seq_size); MINIMAL_CHECK(decode_steps > 0); std::vector output_tokens; - output_tokens.reserve(decode_steps); + output_tokens.reserve(decode_steps - decode_input_size); int next_token = backend_data->prompt_tokens[prefill_seq_size - 1]; int next_position = prefill_seq_size - 1; + int next_input_position = 0; // only used if we need to put input in the decode runner for (int i = 0; i < decode_steps; ++i) { decode_input->data.i32[0] = next_token; decode_input_pos->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); - next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); - output_tokens.push_back(next_token); + + // if the input is larger than the maximum prefill size, the decode step takes the rest, without checking logits + if (decode_input_size > 0 && next_input_position < decode_input_size) { + next_token = backend_data->prompt_tokens[max_seq_size + next_input_position++]; + //LOG(INFO) << "position is " << std::to_string(next_input_position) << "/" << std::to_string(decode_seq_size) << std::endl; + } + else { + next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); + LOG(INFO) << backend_data->sp_processor->IdToPiece(next_token) << std::endl; + if (next_token == backend_data->stop_token_id) break; + output_tokens.push_back(next_token); + } next_position += 1; - if (next_token == backend_data->stop_token_id) break; } MINIMAL_CHECK(backend_data->sp_processor->Decode(output_tokens, &backend_data->output).ok()); + LOG(INFO) << "Output: " << backend_data->output << std::endl; + return MLPERF_SUCCESS; } @@ -167,8 +180,9 @@ mlperf_status_t LLMPipeline::backend_flush_queries(mlperf_backend_ptr_t backend_ } // Return the number of inputs of the model. +// Only 1 input needs to be provided, which is the tokens themselves, the other inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { - return 2; + return 1; } // Return the type of the ith input. @@ -181,7 +195,10 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; std::string prompt = std::string(static_cast(data)); - MINIMAL_CHECK(backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens).ok()); //TODO + + + LOG(INFO) << "Input: " << prompt << std::endl; + MINIMAL_CHECK(backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens).ok()); if (!backend_data->start_token.empty()) { backend_data->prompt_tokens.insert(backend_data->prompt_tokens.begin(), backend_data->sp_processor->PieceToId((backend_data->start_token))); @@ -295,17 +312,26 @@ tflite::SignatureRunner *LLMPipeline::GetPrefillRunner(tflite::Interpreter* inte tflite::SignatureRunner* runner = nullptr; //int best_seq_size = -1; size_t delta = std::numeric_limits::max(); + size_t max_prefill_size = 0; + std::string max_prefill_key = std::string(""); for (const std::string* key : interpreter->signature_keys()) { if (key->find("prefill") == std::string::npos) continue; TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())->input_tensor("input_pos"); // The expected shape for input position is [Seq]. size_t seq_size = input_pos->dims->data[0]; + //TODO this could be else maybe? + if (seq_size > max_prefill_size) { + max_prefill_size = seq_size; + max_prefill_key = std::string(key->c_str()); + } if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) { runner = interpreter->GetSignatureRunner(key->c_str()); //best_seq_size = seq_size; delta = seq_size - num_input_tokens; } } + //fallback to maximum possible size if a runner is not found (most likely because the seq_size is larger than max_prefill_size) + if (!runner && max_prefill_key != "") runner = interpreter->GetSignatureRunner(max_prefill_key.c_str()); MINIMAL_CHECK_PTR(runner != nullptr); PrepareRunner(runner, kv_cache); return runner; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index ad29c6b33..955cfc017 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -86,7 +86,8 @@ struct LLMBackendData { kv_cache_t kv_cache; //std::string input_prompt; std::vector prompt_tokens; - uint8_t threads = 1; + uint8_t threads = 30; + uint32_t max_output_tokens = 2; std::string start_token = ""; std::string end_token = ""; int stop_token_id = -1; From 057c9f8ae055054a44b501b8ad935b849be2f3b1 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 1 Sep 2025 07:07:40 +0300 Subject: [PATCH 10/74] added support for loadgen's token based performance measurement + implemented performance benchmark for LLM pipeline --- flutter/cpp/backend.h | 3 + flutter/cpp/backends/external.cc | 2 + flutter/cpp/backends/external.h | 8 ++ flutter/cpp/binary/main.cc | 2 +- flutter/cpp/dataset.h | 4 + flutter/cpp/datasets/mmlu_gen.cc | 10 +- flutter/cpp/datasets/mmlu_gen.h | 5 + flutter/cpp/mlperf_driver.cc | 36 +++++- flutter/cpp/mlperf_driver.h | 3 +- .../cpp/backend_tflite/llm_pipeline.cc | 122 ++++++++++++------ .../cpp/backend_tflite/llm_pipeline.h | 15 ++- .../cpp/backend_tflite/pipeline.h | 5 +- .../backend_tflite/single_model_pipeline.h | 4 + .../stable_diffusion_pipeline.h | 4 + .../cpp/backend_tflite/tflite_c.cc | 4 + 15 files changed, 172 insertions(+), 55 deletions(-) diff --git a/flutter/cpp/backend.h b/flutter/cpp/backend.h index deb8fdc89..e91454b55 100644 --- a/flutter/cpp/backend.h +++ b/flutter/cpp/backend.h @@ -44,6 +44,9 @@ class Backend { // Accelerator name. virtual const std::string& AcceleratorName() const = 0; + // Run inference for token based input (such as LLM prompt). Only needed for LLMs currently. + virtual void IssueFirstTokenQuery() = 0; + // Run inference for a sample. Inputs is already set by SetInputs. virtual void IssueQuery() = 0; diff --git a/flutter/cpp/backends/external.cc b/flutter/cpp/backends/external.cc index 3675e4943..363c6863e 100644 --- a/flutter/cpp/backends/external.cc +++ b/flutter/cpp/backends/external.cc @@ -159,6 +159,8 @@ BackendFunctions::BackendFunctions(const std::string& lib_path) { destroy = reinterpret_cast(GetSymbol("mlperf_backend_delete")); + issue_first_token_query = reinterpret_cast( + GetSymbol("mlperf_backend_issue_first_token_query")); issue_query = reinterpret_cast( GetSymbol("mlperf_backend_issue_query")); flush_queries = reinterpret_cast( diff --git a/flutter/cpp/backends/external.h b/flutter/cpp/backends/external.h index 12e17357e..1d9d4e02c 100644 --- a/flutter/cpp/backends/external.h +++ b/flutter/cpp/backends/external.h @@ -47,6 +47,8 @@ struct BackendFunctions { using AcceleratorNamePtr = std::add_pointer::type; using BackendDeletePtr = std::add_pointer::type; + using IssueFirstTokenQueryPtr = + std::add_pointer::type; using IssueQueryPtr = std::add_pointer::type; using FlushQueriesPtr = @@ -78,6 +80,7 @@ struct BackendFunctions { AcceleratorNamePtr accelerator_name{nullptr}; BackendDeletePtr destroy{nullptr}; + IssueFirstTokenQueryPtr issue_first_token_query{nullptr}; IssueQueryPtr issue_query{nullptr}; FlushQueriesPtr flush_queries{nullptr}; @@ -156,6 +159,11 @@ class ExternalBackend : public Backend { return accelerator_name_; } + void IssueFirstTokenQuery() override { + if (backend_functions_.issue_first_token_query(backend_ptr_) != MLPERF_SUCCESS) { + LOG(FATAL) << "Error while inferencing model for first token"; + } + } // Run inference for a sample. void IssueQuery() override { if (backend_functions_.issue_query(backend_ptr_) != MLPERF_SUCCESS) { diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index df854627b..fc97e3722 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -439,7 +439,7 @@ int Main(int argc, char *argv[]) { batch_size); driver.RunMLPerfTest(mode, min_query_count, min_duration_ms / 1000.0, max_duration_ms / 1000.0, - single_stream_expected_latency_ns, output_dir); + single_stream_expected_latency_ns, output_dir, benchmark_id=="llm"); LOG(INFO) << "Accuracy: " << driver.ComputeAccuracyString(); return 0; } diff --git a/flutter/cpp/dataset.h b/flutter/cpp/dataset.h index a753ceadc..ebdd3fc3a 100644 --- a/flutter/cpp/dataset.h +++ b/flutter/cpp/dataset.h @@ -60,6 +60,10 @@ class Dataset : public ::mlperf::QuerySampleLibrary { virtual std::vector ProcessOutput( const int sample_idx, const std::vector& outputs) = 0; + // Should be called after ProcessOutput. + virtual int64_t GetOutputTokenCount( + const int sample_idx) {return 0;} + virtual bool HasAccuracy() { return false; } // ComputeAccuracy calculates the accuracy of the processed outputs. This diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 742237da8..34d586572 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -12,6 +12,7 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord/*, const st : sample_reader_(input_tfrecord), Dataset(backend) { std::cout << "MMLUT-DATASET: " << "Initializing with TFRecord " << input_tfrecord << " with sample size " << std::to_string(sample_reader_.Size()) << std::endl; // Load all TFRecord samples into memory + //TODO move to MmluGen::LoadSamplesToRam? for (size_t i = 0; i < sample_reader_.Size(); i++) { tensorflow::tstring record = sample_reader_.ReadRecord(i); tensorflow::Example example; @@ -26,12 +27,13 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord/*, const st std::cout << "MMLUT-DATASET: " << "Loading TFRecord Data index " << std::to_string(i) << " with answer {" << answer << "}" << std::endl; samples_.push_back(std::move(sample)); + sample_output_token_counts_.push_back(0); } //LoadSentencePieceProcessor(input_sppp); } void MmluGen::LoadSamplesToRam(const std::vector& samples) { - std::cout << "MMLUT-DATASET: " << "Loading Samples..." << std::endl; + std::cout << "MMLUT-DATASET: " << "Loading " << std::to_string(samples.size()) << " samples..." << std::endl; for (auto id : samples) { loaded_sample_ids_.insert(id); } @@ -55,6 +57,7 @@ std::vector MmluGen::GetData(int sample_idx) { std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { if (sample_idx >= samples_.size() || outputs.empty()) return {0}; + sample_output_token_counts_[sample_idx] = reinterpret_cast*>(outputs[1])->size(); const char* prediction = reinterpret_cast(outputs[0]); char predicted_char = prediction[1]; // Assume second token is the answer because of whitespace (e.g., 'A', 'B', ...) std::cout << "MMLUT-DATASET: " << "Predicted answer: " << predicted_char << std::endl; @@ -69,6 +72,11 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vec return {static_cast(is_correct)}; } + +int64_t MmluGen::GetOutputTokenCount(const int sample_idx) { + return sample_output_token_counts_[sample_idx]; +} + bool MmluGen::HasAccuracy() { return true; } diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 9773d2338..7a2a96705 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -25,6 +25,8 @@ class MmluGen : public Dataset { size_t TotalSampleCount() override { return samples_.size(); } + size_t PerformanceSampleCount() override { return 1; } + void LoadSamplesToRam(const std::vector& samples) override; void UnloadSamplesFromRam(const std::vector& samples) override; @@ -33,6 +35,8 @@ class MmluGen : public Dataset { std::vector ProcessOutput(const int sample_idx, const std::vector& outputs) override; + int64_t GetOutputTokenCount(const int sample_idx) override; + bool HasAccuracy() override; float ComputeAccuracy() override; @@ -54,6 +58,7 @@ class MmluGen : public Dataset { }; std::vector> samples_; + std::vector sample_output_token_counts_; std::set loaded_sample_ids_; size_t correct_ = 0; diff --git a/flutter/cpp/mlperf_driver.cc b/flutter/cpp/mlperf_driver.cc index d49e08f13..12f94fc53 100644 --- a/flutter/cpp/mlperf_driver.cc +++ b/flutter/cpp/mlperf_driver.cc @@ -31,6 +31,7 @@ namespace mobile { void MlperfDriver::IssueQuery( const std::vector<::mlperf::QuerySample>& samples) { std::vector<::mlperf::QuerySampleResponse> responses; + std::vector<::mlperf::QuerySampleResponse> ft_responses; std::vector> response_data; if (scenario_ == "Offline") { @@ -67,15 +68,33 @@ void MlperfDriver::IssueQuery( ::mlperf::QuerySample sample = samples.at(idx); std::vector inputs = dataset_->GetData(sample.index); backend_->SetInputs(inputs); + + if (use_tokens_) { + ft_responses.clear(); + backend_->IssueFirstTokenQuery(); + ft_responses.push_back({sample.id, reinterpret_cast(nullptr), 0}); + ::mlperf::FirstTokenComplete(ft_responses.data(), ft_responses.size()); + } + backend_->IssueQuery(); + // Report to mlperf. std::vector outputs = backend_->GetPredictedOutputs(); response_data.push_back(dataset_->ProcessOutput(sample.index, outputs)); - responses.push_back( - {sample.id, - reinterpret_cast(response_data[idx].data()), - response_data[idx].size()}); + if (use_tokens_){ + responses.push_back( + {sample.id, + reinterpret_cast(response_data[idx].data()), + response_data[idx].size(), + dataset_->GetOutputTokenCount(sample.index)}); + } + else { + responses.push_back( + {sample.id, + reinterpret_cast(response_data[idx].data()), + response_data[idx].size()}); + } backend_->FlushQueries(); query_counter_ += 1; } @@ -86,7 +105,7 @@ void MlperfDriver::IssueQuery( void MlperfDriver::RunMLPerfTest(const std::string& mode, int min_query_count, double min_duration, double max_duration, int single_stream_expected_latency_ns, - const std::string& output_dir) { + const std::string& output_dir, bool use_tokens) { ::mlperf::LogSettings log_settings; log_settings.log_output.outdir = output_dir; log_settings.log_output.copy_summary_to_stdout = true; @@ -97,7 +116,12 @@ void MlperfDriver::RunMLPerfTest(const std::string& mode, int min_query_count, mlperf_settings.sample_index_rng_seed = 10688027786191513374UL; mlperf_settings.schedule_rng_seed = 14962580496156340209UL; - mlperf_settings.min_query_count = min_query_count; + //mlperf_settings.min_query_count = 1; + //mlperf_settings.max_query_count = 2; + //mlperf_settings.performance_sample_count_override = 5; + use_tokens_ = use_tokens; + mlperf_settings.use_token_latencies = use_tokens; + //mlperf_settings.server_target_qps = 0.1; mlperf_settings.mode = Str2TestMode(mode); mlperf_settings.min_duration_ms = static_cast(std::ceil(min_duration * 1000.0)); diff --git a/flutter/cpp/mlperf_driver.h b/flutter/cpp/mlperf_driver.h index 14491b486..91abe78c0 100644 --- a/flutter/cpp/mlperf_driver.h +++ b/flutter/cpp/mlperf_driver.h @@ -46,7 +46,7 @@ class MlperfDriver : public ::mlperf::SystemUnderTest { void RunMLPerfTest(const std::string& mode, int min_query_count, double min_duration, double max_duration, int single_stream_expected_latency_ns, - const std::string& output_dir); + const std::string& output_dir, bool use_tokens = false); // A human-readable string for logging purposes. const std::string& Name() override { return backend_->Name(); } @@ -77,6 +77,7 @@ class MlperfDriver : public ::mlperf::SystemUnderTest { std::string scenario_; int batch_; std::atomic query_counter_{0}; + bool use_tokens_; }; } // namespace mobile diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 5bbf31efb..be89f1158 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -27,12 +27,7 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "flutter/cpp/utils.h" #include "tensorflow/lite/c/common.h" -#if __ANDROID__ -#include -#include "tensorflow/lite/delegates/gpu/delegate.h" -#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#endif #include "tensorflow/core/platform/logging.h" #ifdef __cplusplus @@ -89,6 +84,10 @@ mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_ return nullptr; } + if (!backend_data->end_token.empty()) { + backend_data->stop_token_id = backend_data->sp_processor->PieceToId((backend_data->end_token)); + } + return backend_data; } @@ -109,12 +108,12 @@ const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->name; } -// Run the inference for a sample. -mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +//TODO this needs to include all token processing until the model produces the first output token +mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - // Get Input Tensors for each of the runners. - // Shape: [Batch, Seq], Dtype: int32 + // Get Input Tensors for each of the runners. + // Shape: [Batch, Seq], Dtype: int32 TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); // Shape: [Seq], Dtype: int32 TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); @@ -139,43 +138,87 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); - // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle - int decode_steps = std::min(backend_data->max_output_tokens+decode_input_size, kv_cache_max_size - prefill_seq_size); - MINIMAL_CHECK(decode_steps > 0); + // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle + int decode_steps = decode_input_size; - std::vector output_tokens; - output_tokens.reserve(decode_steps - decode_input_size); + //backend_data->output_tokens.reserve(decode_steps - decode_input_size); int next_token = backend_data->prompt_tokens[prefill_seq_size - 1]; int next_position = prefill_seq_size - 1; int next_input_position = 0; // only used if we need to put input in the decode runner - for (int i = 0; i < decode_steps; ++i) { + for (int i = 0; i < decode_steps-1; ++i) { decode_input->data.i32[0] = next_token; decode_input_pos->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); - // if the input is larger than the maximum prefill size, the decode step takes the rest, without checking logits - if (decode_input_size > 0 && next_input_position < decode_input_size) { - next_token = backend_data->prompt_tokens[max_seq_size + next_input_position++]; + next_token = backend_data->prompt_tokens[max_seq_size + next_input_position++]; //LOG(INFO) << "position is " << std::to_string(next_input_position) << "/" << std::to_string(decode_seq_size) << std::endl; - } - else { - next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); - LOG(INFO) << backend_data->sp_processor->IdToPiece(next_token) << std::endl; - if (next_token == backend_data->stop_token_id) break; - output_tokens.push_back(next_token); - } next_position += 1; } - MINIMAL_CHECK(backend_data->sp_processor->Decode(output_tokens, &backend_data->output).ok()); + return MLPERF_SUCCESS; +} - LOG(INFO) << "Output: " << backend_data->output << std::endl; +// Run the inference for a sample. +mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + + // Get Input Tensors for each of the runners. + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); + // Shape: [Seq], Dtype: int32 + TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* decode_input = backend_data->decode_runner->input_tensor("tokens"); + // Shape: [Seq], Dtype: int32 + TfLiteTensor* decode_input_pos = backend_data->decode_runner->input_tensor("input_pos"); + // shape: [Batch, kv_cache_max, num_query_groups, head_dim] + TfLiteTensor* kv_cache_k_0 = backend_data->decode_runner->input_tensor("kv_cache_k_0"); + + int kv_cache_max_size = kv_cache_k_0->dims->data[1]; + size_t input_size = backend_data->prompt_tokens.size(); + + //std::memset(prefill_input->data.i32, 0, prefill_input->bytes); + //std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); + //for (int i = 0; i < prefill_seq_size - 1; ++i) { + // prefill_input->data.i32[i] = backend_data->prompt_tokens[i]; + // prefill_input_pos->data.i32[i] = i; + //} + + //MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); + + // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle + int decode_steps = std::min(backend_data->max_output_tokens, kv_cache_max_size - (static_cast(input_size) - 1)); + MINIMAL_CHECK(decode_steps > 0); + + //backend_data->output_tokens.reserve(decode_steps - decode_input_size); + int next_token = backend_data->prompt_tokens[input_size - 1]; + int next_position = input_size - 1; + for (int i = 0; i < decode_steps; ++i) { + decode_input->data.i32[0] = next_token; + decode_input_pos->data.i32[0] = next_position; + MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); + + next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); + //LOG(INFO) << backend_data->sp_processor->IdToPiece(next_token) << std::endl; + if (next_token == backend_data->stop_token_id) break; + backend_data->output_tokens.push_back(next_token); + next_position += 1; + } return MLPERF_SUCCESS; } // Flush the staged queries immediately. mlperf_status_t LLMPipeline::backend_flush_queries(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + + backend_data->prompt_tokens.clear(); + backend_data->output_tokens.clear(); + + for (auto& [_, vec]: backend_data->kv_cache) { + std::fill(vec.begin(), vec.end(), 0.0f); + } + return MLPERF_SUCCESS; } @@ -204,12 +247,6 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, backend_data->prompt_tokens.insert(backend_data->prompt_tokens.begin(), backend_data->sp_processor->PieceToId((backend_data->start_token))); } - // NOTE block below can be moved safely to backend_create - if (!backend_data->end_token.empty()) { - backend_data->stop_token_id = backend_data->sp_processor->PieceToId((backend_data->end_token)); - } - // --- - uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; //assuming max tokens is <16k backend_data->prefill_runner = GetPrefillRunner(backend_data->interpreter, effective_prefill_token_size, backend_data->kv_cache); @@ -220,7 +257,7 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, // Return the number of outputs for the model. int32_t LLMPipeline::backend_get_output_count(mlperf_backend_ptr_t backend_ptr) { - return 1; + return 2; // 0 is the output string, 1 is the output tokens } // Return the type of ith output. @@ -232,12 +269,17 @@ mlperf_data_t LLMPipeline::backend_get_output_type(mlperf_backend_ptr_t backend_ mlperf_status_t LLMPipeline::backend_get_output(mlperf_backend_ptr_t backend_ptr, uint32_t batch_index, int32_t i, void **data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - if (i == 0) { + if (i == 0){ + MINIMAL_CHECK(backend_data->sp_processor->Decode(backend_data->output_tokens, &backend_data->output).ok()); + LOG(INFO) << "Output: " << backend_data->output << std::endl; + *data = backend_data->output.data(); - return MLPERF_SUCCESS; } - - return MLPERF_FAILURE; + else if (i == 1) { + *data = &backend_data->output_tokens; + } + else return MLPERF_FAILURE; + return MLPERF_SUCCESS; } void LLMPipeline::backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, uint8_t *data) {} @@ -287,9 +329,9 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { TfLiteTensor* tensor = runner->input_tensor(k_cache_name.c_str()); size_t count = tensor->bytes / sizeof(float); kv_cache.emplace(k_cache_name, - std::vector>(count, 0.0)); + std::vector>(count, 0.0f)); kv_cache.emplace(v_cache_name, - std::vector>(count, 0.0)); + std::vector>(count, 0.0f)); } return kv_cache; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 955cfc017..c79c28213 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -58,9 +58,9 @@ class AlignedAllocator { T* allocate(std::size_t n) { void* ptr; std::size_t size = n * sizeof(T); - std::size_t padding = tflite::kDefaultTensorAlignment - - (size % tflite::kDefaultTensorAlignment); - size += padding; + //std::size_t padding = tflite::kDefaultTensorAlignment - + // (size % tflite::kDefaultTensorAlignment); + //size += padding; int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size); if (ret != 0) { return nullptr; @@ -86,12 +86,13 @@ struct LLMBackendData { kv_cache_t kv_cache; //std::string input_prompt; std::vector prompt_tokens; + std::vector output_tokens; + std::string output; uint8_t threads = 30; - uint32_t max_output_tokens = 2; + int max_output_tokens = 2; std::string start_token = ""; std::string end_token = ""; int stop_token_id = -1; - std::string output; LLMBackendData(){} @@ -133,6 +134,10 @@ class LLMPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_status_t backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) override; + mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/pipeline.h b/mobile_back_tflite/cpp/backend_tflite/pipeline.h index 4ab1b4f1c..86747ffff 100644 --- a/mobile_back_tflite/cpp/backend_tflite/pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/pipeline.h @@ -40,6 +40,9 @@ class Pipeline { // Return the name of this backend. virtual const char *backend_name(mlperf_backend_ptr_t backend_ptr) = 0; + virtual mlperf_status_t backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) = 0; + // Run the inference for a sample. virtual mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) = 0; @@ -87,4 +90,4 @@ class Pipeline { virtual void backend_release_buffer(void *p) = 0; }; -#endif // TFLITE_PIPELINE_H_ \ No newline at end of file +#endif // TFLITE_PIPELINE_H_ diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h index 70d447588..78f09a66a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h @@ -37,6 +37,10 @@ class SingleModelPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_status_t backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) override {return MLPERF_FAILURE;} + mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h index 17070a286..0b4046152 100644 --- a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h @@ -64,6 +64,10 @@ class StableDiffusionPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; + + mlperf_status_t backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) override {return MLPERF_FAILURE;} + mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc index 64ce37147..4639b54cc 100644 --- a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc +++ b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc @@ -186,6 +186,10 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { reset_pipeline(); } +mlperf_status_t mlperf_backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { + return pipeline->backend_issue_first_token_query(backend_ptr); +} + // Run the inference for a sample. mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { return pipeline->backend_issue_query(backend_ptr); From 3c8b4f5cae6f617b4cfa457e74a420fc993ede2c Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 2 Sep 2025 00:42:23 +0300 Subject: [PATCH 11/74] fixed bugs in inference process, first token function now handles only input and issue_query only handles output tokens --- .../cpp/backend_tflite/llm_pipeline.cc | 57 ++++++++----------- 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index be89f1158..e67106708 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -108,12 +108,14 @@ const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->name; } -//TODO this needs to include all token processing until the model produces the first output token +// TODO chunked prefill support +// Run the prefill inference and at least 1 output token producing decode inference. +// This function exclusively handles the input tokens. mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - // Get Input Tensors for each of the runners. - // Shape: [Batch, Seq], Dtype: int32 + // Get Input Tensors for each of the runners. + // Shape: [Batch, Seq], Dtype: int32 TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); // Shape: [Seq], Dtype: int32 TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); @@ -127,38 +129,36 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_ int max_seq_size = prefill_input->dims->data[1]; int kv_cache_max_size = kv_cache_k_0->dims->data[1]; int prefill_seq_size = std::min(static_cast(backend_data->prompt_tokens.size()), max_seq_size); - int decode_input_size = std::max(static_cast(backend_data->prompt_tokens.size()) - max_seq_size, 0); // Making sure the decode seq isn't negative for later subtractions + bool prefill_overflow = static_cast(backend_data->prompt_tokens.size()) > max_seq_size; + int overflow_size = prefill_overflow ? static_cast(backend_data->prompt_tokens.size()) - max_seq_size : 0; + int prefill_amount = prefill_overflow ? prefill_seq_size : (prefill_seq_size - 1); + int decode_tokens = prefill_overflow ? overflow_size - 1 : 1; std::memset(prefill_input->data.i32, 0, prefill_input->bytes); std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); - for (int i = 0; i < prefill_seq_size - 1; ++i) { + // If the prefill can fit the entire input, leave one token for decode, otherwise prefill as much of the input as possible. + for (int i = 0; i < prefill_amount; ++i) { prefill_input->data.i32[i] = backend_data->prompt_tokens[i]; prefill_input_pos->data.i32[i] = i; } MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); - // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle - int decode_steps = decode_input_size; - - //backend_data->output_tokens.reserve(decode_steps - decode_input_size); - int next_token = backend_data->prompt_tokens[prefill_seq_size - 1]; - int next_position = prefill_seq_size - 1; - int next_input_position = 0; // only used if we need to put input in the decode runner - for (int i = 0; i < decode_steps-1; ++i) { + // Run decode once if input fits inside prefill, otherwise decode the rest of the input one by one + int next_token = backend_data->prompt_tokens[prefill_amount]; + int next_position = prefill_amount; + for (int i = 0; i < decode_tokens; ++i) { decode_input->data.i32[0] = next_token; decode_input_pos->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); - // if the input is larger than the maximum prefill size, the decode step takes the rest, without checking logits - next_token = backend_data->prompt_tokens[max_seq_size + next_input_position++]; - //LOG(INFO) << "position is " << std::to_string(next_input_position) << "/" << std::to_string(decode_seq_size) << std::endl; - next_position += 1; + next_token = backend_data->prompt_tokens[++next_position]; } return MLPERF_SUCCESS; } -// Run the inference for a sample. +// Run the output token producing decode inference. +// This function exclusively takes output tokens to produce more output tokens. mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; @@ -177,29 +177,20 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt int kv_cache_max_size = kv_cache_k_0->dims->data[1]; size_t input_size = backend_data->prompt_tokens.size(); - //std::memset(prefill_input->data.i32, 0, prefill_input->bytes); - //std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); - //for (int i = 0; i < prefill_seq_size - 1; ++i) { - // prefill_input->data.i32[i] = backend_data->prompt_tokens[i]; - // prefill_input_pos->data.i32[i] = i; - //} - - //MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); - // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle - int decode_steps = std::min(backend_data->max_output_tokens, kv_cache_max_size - (static_cast(input_size) - 1)); + int decode_steps = std::min(backend_data->max_output_tokens, kv_cache_max_size - static_cast(input_size)); MINIMAL_CHECK(decode_steps > 0); - //backend_data->output_tokens.reserve(decode_steps - decode_input_size); - int next_token = backend_data->prompt_tokens[input_size - 1]; - int next_position = input_size - 1; + backend_data->output_tokens.reserve(decode_steps); + int next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); + if (next_token == backend_data->stop_token_id) return MLPERF_SUCCESS; + backend_data->output_tokens.push_back(next_token); + int next_position = input_size; for (int i = 0; i < decode_steps; ++i) { decode_input->data.i32[0] = next_token; decode_input_pos->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); - next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); - //LOG(INFO) << backend_data->sp_processor->IdToPiece(next_token) << std::endl; if (next_token == backend_data->stop_token_id) break; backend_data->output_tokens.push_back(next_token); next_position += 1; From a03fbea216922f5399e12b4c4bd3989549cbce1d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 8 Sep 2025 00:54:26 +0300 Subject: [PATCH 12/74] optimized tensor retrieval for inference + added check for input size vs KV cache size --- .../cpp/backend_tflite/llm_pipeline.cc | 87 ++++++++----------- .../cpp/backend_tflite/llm_pipeline.h | 62 ++++++++++--- 2 files changed, 88 insertions(+), 61 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index e67106708..fd4bd91cb 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -1,4 +1,4 @@ -/* Copyright 2020-2021 The MLPerf Authors. All Rights Reserved. +/* Copyright 2025 The MLPerf Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "llm_pipeline.h" #include @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/lite/c/common.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/delegates/gpu/delegate.h" #ifdef __cplusplus extern "C" { @@ -108,38 +110,25 @@ const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->name; } -// TODO chunked prefill support // Run the prefill inference and at least 1 output token producing decode inference. // This function exclusively handles the input tokens. mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - // Get Input Tensors for each of the runners. - // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); - // Shape: [Seq], Dtype: int32 - TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); - // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* decode_input = backend_data->decode_runner->input_tensor("tokens"); - // Shape: [Seq], Dtype: int32 - TfLiteTensor* decode_input_pos = backend_data->decode_runner->input_tensor("input_pos"); - // shape: [Batch, kv_cache_max, num_query_groups, head_dim] - TfLiteTensor* kv_cache_k_0 = backend_data->decode_runner->input_tensor("kv_cache_k_0"); - - int max_seq_size = prefill_input->dims->data[1]; - int kv_cache_max_size = kv_cache_k_0->dims->data[1]; + int max_seq_size = backend_data->tensors.prefill_input()->dims->data[1]; + int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; int prefill_seq_size = std::min(static_cast(backend_data->prompt_tokens.size()), max_seq_size); bool prefill_overflow = static_cast(backend_data->prompt_tokens.size()) > max_seq_size; int overflow_size = prefill_overflow ? static_cast(backend_data->prompt_tokens.size()) - max_seq_size : 0; int prefill_amount = prefill_overflow ? prefill_seq_size : (prefill_seq_size - 1); int decode_tokens = prefill_overflow ? overflow_size - 1 : 1; - std::memset(prefill_input->data.i32, 0, prefill_input->bytes); - std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes); + std::memset(backend_data->tensors.prefill_input()->data.i32, 0, backend_data->tensors.prefill_input()->bytes); + std::memset(backend_data->tensors.prefill_input_pos()->data.i32, 0, backend_data->tensors.prefill_input_pos()->bytes); // If the prefill can fit the entire input, leave one token for decode, otherwise prefill as much of the input as possible. for (int i = 0; i < prefill_amount; ++i) { - prefill_input->data.i32[i] = backend_data->prompt_tokens[i]; - prefill_input_pos->data.i32[i] = i; + backend_data->tensors.prefill_input()->data.i32[i] = backend_data->prompt_tokens[i]; + backend_data->tensors.prefill_input_pos()->data.i32[i] = i; } MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); @@ -148,8 +137,8 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_ int next_token = backend_data->prompt_tokens[prefill_amount]; int next_position = prefill_amount; for (int i = 0; i < decode_tokens; ++i) { - decode_input->data.i32[0] = next_token; - decode_input_pos->data.i32[0] = next_position; + backend_data->tensors.decode_input()->data.i32[0] = next_token; + backend_data->tensors.decode_input_pos()->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); next_token = backend_data->prompt_tokens[++next_position]; } @@ -162,35 +151,25 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - // Get Input Tensors for each of the runners. - // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* prefill_input = backend_data->prefill_runner->input_tensor("tokens"); - // Shape: [Seq], Dtype: int32 - TfLiteTensor* prefill_input_pos = backend_data->prefill_runner->input_tensor("input_pos"); - // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* decode_input = backend_data->decode_runner->input_tensor("tokens"); - // Shape: [Seq], Dtype: int32 - TfLiteTensor* decode_input_pos = backend_data->decode_runner->input_tensor("input_pos"); - // shape: [Batch, kv_cache_max, num_query_groups, head_dim] - TfLiteTensor* kv_cache_k_0 = backend_data->decode_runner->input_tensor("kv_cache_k_0"); - - int kv_cache_max_size = kv_cache_k_0->dims->data[1]; + + int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; size_t input_size = backend_data->prompt_tokens.size(); - // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle - int decode_steps = std::min(backend_data->max_output_tokens, kv_cache_max_size - static_cast(input_size)); + // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle. + // take away 1 from max_output_tokens because backend_issue_first_token_query always generates the first output token. + int decode_steps = std::min(backend_data->max_output_tokens-1, kv_cache_max_size - static_cast(input_size)); MINIMAL_CHECK(decode_steps > 0); backend_data->output_tokens.reserve(decode_steps); - int next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); + int next_token = GreedySampler(backend_data->tensors.logits_output()); if (next_token == backend_data->stop_token_id) return MLPERF_SUCCESS; backend_data->output_tokens.push_back(next_token); int next_position = input_size; for (int i = 0; i < decode_steps; ++i) { - decode_input->data.i32[0] = next_token; - decode_input_pos->data.i32[0] = next_position; + backend_data->tensors.decode_input()->data.i32[0] = next_token; + backend_data->tensors.decode_input_pos()->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); - next_token = GreedySampler(backend_data->decode_runner->output_tensor("logits")); + next_token = GreedySampler(backend_data->tensors.logits_output()); if (next_token == backend_data->stop_token_id) break; backend_data->output_tokens.push_back(next_token); next_position += 1; @@ -201,15 +180,6 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt // Flush the staged queries immediately. mlperf_status_t LLMPipeline::backend_flush_queries(mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - - backend_data->prompt_tokens.clear(); - backend_data->output_tokens.clear(); - - for (auto& [_, vec]: backend_data->kv_cache) { - std::fill(vec.begin(), vec.end(), 0.0f); - } - return MLPERF_SUCCESS; } @@ -230,8 +200,14 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, std::string prompt = std::string(static_cast(data)); + // Reset the tokens and kv caches from potential previous runs. + backend_data->prompt_tokens.clear(); + backend_data->output_tokens.clear(); + + for (auto& [_, vec]: backend_data->kv_cache) { + std::fill(vec.begin(), vec.end(), 0.0f); + } - LOG(INFO) << "Input: " << prompt << std::endl; MINIMAL_CHECK(backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens).ok()); if (!backend_data->start_token.empty()) { @@ -242,6 +218,15 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, backend_data->prefill_runner = GetPrefillRunner(backend_data->interpreter, effective_prefill_token_size, backend_data->kv_cache); + // Get the necessary tensor pointers for inference. + backend_data->tensors.get_tensors(backend_data->prefill_runner, backend_data->decode_runner); + + if (effective_prefill_token_size+1 > backend_data->tensors.kv_cache_k_0()->dims->data[1]) { + LOG(ERROR) << "Input size (" << std::to_string(effective_prefill_token_size+1) << ") exceeds KV cache limit (" << std::to_string(backend_data->tensors.kv_cache_k_0()->dims->data[1]) << ")." << std::endl; + return MLPERF_FAILURE; + } + + return MLPERF_SUCCESS; } diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index c79c28213..1c546a3d4 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -1,8 +1,11 @@ -/* Copyright 2024 The MLPerf Authors. All Rights Reserved. +/* Copyright 2025 The MLPerf Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -21,7 +24,6 @@ limitations under the License. #include "pipeline.h" #include "src/sentencepiece_processor.h" -#include "tensorflow/lite/experimental/genai/genai_ops.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/kernels/register.h" @@ -58,6 +60,7 @@ class AlignedAllocator { T* allocate(std::size_t n) { void* ptr; std::size_t size = n * sizeof(T); + // NOTE this part of the code from seems to be redundant //std::size_t padding = tflite::kDefaultTensorAlignment - // (size % tflite::kDefaultTensorAlignment); //size += padding; @@ -73,6 +76,49 @@ class AlignedAllocator { using kv_cache_t = std::map>>; +// A simple container for pointers to the tensors used during inference. +// The pointers here should not be managed or deleted by this struct. +struct LLMTensors { + + bool get_tensors (tflite::SignatureRunner *prefill_runner, tflite::SignatureRunner *decode_runner) { + prefill_input_ = prefill_runner->input_tensor("tokens"); + prefill_input_pos_ = prefill_runner->input_tensor("input_pos"); + decode_input_ = decode_runner->input_tensor("tokens"); + decode_input_pos_ = decode_runner->input_tensor("input_pos"); + logits_output_ = decode_runner->output_tensor("logits"); + kv_cache_k_0_ = decode_runner->input_tensor("kv_cache_k_0"); + + // Making sure none of the tensors are nullptr. + return prefill_input_ && prefill_input_pos_ && decode_input_ && decode_input_pos_ && logits_output_ && kv_cache_k_0_; + } + + LLMTensors(){} + + LLMTensors(const LLMTensors&) = delete; + LLMTensors& operator=(const LLMTensors&) = delete; + + TfLiteTensor* prefill_input() const {return prefill_input_;} + TfLiteTensor* prefill_input_pos() const {return prefill_input_pos_;} + TfLiteTensor* decode_input() const {return decode_input_;} + TfLiteTensor* decode_input_pos() const {return decode_input_pos_;} + const TfLiteTensor* logits_output() const {return logits_output_;} + TfLiteTensor* kv_cache_k_0() const {return kv_cache_k_0_;} + +private: + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* prefill_input_; + // Shape: [Seq], Dtype: int32 + TfLiteTensor* prefill_input_pos_; + // Shape: [Batch, Seq], Dtype: int32 + TfLiteTensor* decode_input_; + // Shape: [Seq], Dtype: int32 + TfLiteTensor* decode_input_pos_; + // Shape: [Seq], Dtype: float32 + const TfLiteTensor* logits_output_; + // shape: [Batch, kv_cache_max, num_query_groups, head_dim] + TfLiteTensor* kv_cache_k_0_; +}; + struct LLMBackendData { const char *name = "TFLite"; const char *vendor = "Google"; @@ -83,8 +129,8 @@ struct LLMBackendData { tflite::Interpreter *interpreter{}; tflite::SignatureRunner *prefill_runner{nullptr}; tflite::SignatureRunner *decode_runner{nullptr}; + LLMTensors tensors; kv_cache_t kv_cache; - //std::string input_prompt; std::vector prompt_tokens; std::vector output_tokens; std::string output; @@ -106,12 +152,6 @@ struct LLMBackendData { LLMBackendData(const LLMBackendData&) = delete; LLMBackendData& operator=(const LLMBackendData&) = delete; -// uint32_t real_batch_size = 1; -//std::unique_ptr executer; -// int32_t original_tensor_size = 0; -//#ifdef MTK_TFLITE_NEURON_BACKEND -// neuron_backend_ptr_t neuronBackendData{nullptr}; -//#endif }; // A simple pipeline which runs a single model. @@ -181,6 +221,8 @@ class LLMPipeline : public Pipeline { sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor(std::string path); int GreedySampler(const TfLiteTensor *logits); + + }; #endif // TFLITE_SINGLE_MODEL_PIPELINE_H_ From 69a630add9a8dfae9ce60a7412d50a4ae6b69d91 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 8 Sep 2025 00:55:39 +0300 Subject: [PATCH 13/74] clang-format --- .../cpp/backend_tflite/llm_pipeline.cc | 226 +++++++++++------- .../cpp/backend_tflite/llm_pipeline.h | 115 ++++----- 2 files changed, 200 insertions(+), 141 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index fd4bd91cb..4fbd09f87 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -27,9 +27,8 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "flutter/cpp/utils.h" -#include "tensorflow/lite/c/common.h" - #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/delegate.h" #ifdef __cplusplus @@ -47,7 +46,9 @@ void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { // Create a new backend and return the pointer to it. // TODO add eos and bos tokens as config parameters -mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_backend_configuration_t *configs, const char *native_lib_path) { +mlperf_backend_ptr_t LLMPipeline::backend_create( + const char *model_path, mlperf_backend_configuration_t *configs, + const char *native_lib_path) { // Verify only one instance of the backend exists at any time if (backendExists) { LOG(ERROR) << "Only one backend instance should exist at a time"; @@ -57,17 +58,20 @@ mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_ LLMBackendData *backend_data = new LLMBackendData(); // sentencePiece Processor Path - std::string sppp = mlperf::mobile::GetConfigValue(configs, "sentencepiece_processor_path", std::string("")); + std::string sppp = mlperf::mobile::GetConfigValue( + configs, "sentencepiece_processor_path", std::string("")); // Load the model. - backend_data->model = tflite::FlatBufferModel::BuildFromFile(model_path).release(); + backend_data->model = + tflite::FlatBufferModel::BuildFromFile(model_path).release(); if (!backend_data->model) { LOG(ERROR) << "Failed to load model: " << model_path; backend_delete(backend_data); return nullptr; } - backend_data->interpreter = BuildInterpreter(backend_data->model, backend_data->threads); + backend_data->interpreter = + BuildInterpreter(backend_data->model, backend_data->threads); if (!backend_data->interpreter) { LOG(ERROR) << "Failed to load interpreter"; backend_delete(backend_data); @@ -75,19 +79,21 @@ mlperf_backend_ptr_t LLMPipeline::backend_create(const char *model_path, mlperf_ } backend_data->kv_cache = BuildKVCache(backend_data->interpreter); - //TODO kv_cache check + // TODO kv_cache check - backend_data->decode_runner = GetDecodeRunner(backend_data->interpreter, backend_data->kv_cache); + backend_data->decode_runner = + GetDecodeRunner(backend_data->interpreter, backend_data->kv_cache); backend_data->sp_processor = LoadSentencePieceProcessor(sppp); if (!backend_data->sp_processor) { LOG(ERROR) << "Failed to load sentencepiece processor: " << sppp; - backend_delete(backend_data); + backend_delete(backend_data); return nullptr; } if (!backend_data->end_token.empty()) { - backend_data->stop_token_id = backend_data->sp_processor->PieceToId((backend_data->end_token)); + backend_data->stop_token_id = + backend_data->sp_processor->PieceToId((backend_data->end_token)); } return backend_data; @@ -99,7 +105,8 @@ const char *LLMPipeline::backend_vendor_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->vendor; } -const char *LLMPipeline::backend_accelerator_name(mlperf_backend_ptr_t backend_ptr) { +const char *LLMPipeline::backend_accelerator_name( + mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; return backend_data->accelerator; } @@ -110,30 +117,42 @@ const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { return backend_data->name; } -// Run the prefill inference and at least 1 output token producing decode inference. -// This function exclusively handles the input tokens. -mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { +// Run the prefill inference and at least 1 output token producing decode +// inference. This function exclusively handles the input tokens. +mlperf_status_t LLMPipeline::backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; int max_seq_size = backend_data->tensors.prefill_input()->dims->data[1]; int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; - int prefill_seq_size = std::min(static_cast(backend_data->prompt_tokens.size()), max_seq_size); - bool prefill_overflow = static_cast(backend_data->prompt_tokens.size()) > max_seq_size; - int overflow_size = prefill_overflow ? static_cast(backend_data->prompt_tokens.size()) - max_seq_size : 0; - int prefill_amount = prefill_overflow ? prefill_seq_size : (prefill_seq_size - 1); + int prefill_seq_size = std::min( + static_cast(backend_data->prompt_tokens.size()), max_seq_size); + bool prefill_overflow = + static_cast(backend_data->prompt_tokens.size()) > max_seq_size; + int overflow_size = + prefill_overflow + ? static_cast(backend_data->prompt_tokens.size()) - max_seq_size + : 0; + int prefill_amount = + prefill_overflow ? prefill_seq_size : (prefill_seq_size - 1); int decode_tokens = prefill_overflow ? overflow_size - 1 : 1; - std::memset(backend_data->tensors.prefill_input()->data.i32, 0, backend_data->tensors.prefill_input()->bytes); - std::memset(backend_data->tensors.prefill_input_pos()->data.i32, 0, backend_data->tensors.prefill_input_pos()->bytes); - // If the prefill can fit the entire input, leave one token for decode, otherwise prefill as much of the input as possible. + std::memset(backend_data->tensors.prefill_input()->data.i32, 0, + backend_data->tensors.prefill_input()->bytes); + std::memset(backend_data->tensors.prefill_input_pos()->data.i32, 0, + backend_data->tensors.prefill_input_pos()->bytes); + // If the prefill can fit the entire input, leave one token for decode, + // otherwise prefill as much of the input as possible. for (int i = 0; i < prefill_amount; ++i) { - backend_data->tensors.prefill_input()->data.i32[i] = backend_data->prompt_tokens[i]; + backend_data->tensors.prefill_input()->data.i32[i] = + backend_data->prompt_tokens[i]; backend_data->tensors.prefill_input_pos()->data.i32[i] = i; } MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); - // Run decode once if input fits inside prefill, otherwise decode the rest of the input one by one + // Run decode once if input fits inside prefill, otherwise decode the rest of + // the input one by one int next_token = backend_data->prompt_tokens[prefill_amount]; int next_position = prefill_amount; for (int i = 0; i < decode_tokens; ++i) { @@ -148,16 +167,19 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query(mlperf_backend_ptr_ // Run the output token producing decode inference. // This function exclusively takes output tokens to produce more output tokens. -mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t LLMPipeline::backend_issue_query( + mlperf_backend_ptr_t backend_ptr) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; size_t input_size = backend_data->prompt_tokens.size(); - // Use a manual number for maximum tokens to generate as long as it's not larger than the KV cache can handle. - // take away 1 from max_output_tokens because backend_issue_first_token_query always generates the first output token. - int decode_steps = std::min(backend_data->max_output_tokens-1, kv_cache_max_size - static_cast(input_size)); + // Use a manual number for maximum tokens to generate as long as it's not + // larger than the KV cache can handle. take away 1 from max_output_tokens + // because backend_issue_first_token_query always generates the first output + // token. + int decode_steps = std::min(backend_data->max_output_tokens - 1, + kv_cache_max_size - static_cast(input_size)); MINIMAL_CHECK(decode_steps > 0); backend_data->output_tokens.reserve(decode_steps); @@ -179,98 +201,122 @@ mlperf_status_t LLMPipeline::backend_issue_query(mlperf_backend_ptr_t backend_pt } // Flush the staged queries immediately. -mlperf_status_t LLMPipeline::backend_flush_queries(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t LLMPipeline::backend_flush_queries( + mlperf_backend_ptr_t backend_ptr) { return MLPERF_SUCCESS; } // Return the number of inputs of the model. -// Only 1 input needs to be provided, which is the tokens themselves, the other inputs are handled by the pipeline +// Only 1 input needs to be provided, which is the tokens themselves, the other +// inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { return 1; } // Return the type of the ith input. -mlperf_data_t LLMPipeline::backend_get_input_type(mlperf_backend_ptr_t backend_ptr, int32_t i) { +mlperf_data_t LLMPipeline::backend_get_input_type( + mlperf_backend_ptr_t backend_ptr, int32_t i) { return mlperf_data_t{mlperf_data_t::Int32, 0}; } // Set the data for ith input. -mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, int32_t batch_index, int32_t i, void *data) { +mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, + int32_t batch_index, int32_t i, + void *data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - std::string prompt = std::string(static_cast(data)); + std::string prompt = std::string(static_cast(data)); // Reset the tokens and kv caches from potential previous runs. backend_data->prompt_tokens.clear(); backend_data->output_tokens.clear(); - for (auto& [_, vec]: backend_data->kv_cache) { + for (auto &[_, vec] : backend_data->kv_cache) { std::fill(vec.begin(), vec.end(), 0.0f); } - MINIMAL_CHECK(backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens).ok()); + MINIMAL_CHECK( + backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens) + .ok()); if (!backend_data->start_token.empty()) { - backend_data->prompt_tokens.insert(backend_data->prompt_tokens.begin(), backend_data->sp_processor->PieceToId((backend_data->start_token))); + backend_data->prompt_tokens.insert( + backend_data->prompt_tokens.begin(), + backend_data->sp_processor->PieceToId((backend_data->start_token))); } - uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; //assuming max tokens is <16k + uint16_t effective_prefill_token_size = + backend_data->prompt_tokens.size() - 1; // assuming max tokens is <16k - backend_data->prefill_runner = GetPrefillRunner(backend_data->interpreter, effective_prefill_token_size, backend_data->kv_cache); + backend_data->prefill_runner = + GetPrefillRunner(backend_data->interpreter, effective_prefill_token_size, + backend_data->kv_cache); // Get the necessary tensor pointers for inference. - backend_data->tensors.get_tensors(backend_data->prefill_runner, backend_data->decode_runner); - - if (effective_prefill_token_size+1 > backend_data->tensors.kv_cache_k_0()->dims->data[1]) { - LOG(ERROR) << "Input size (" << std::to_string(effective_prefill_token_size+1) << ") exceeds KV cache limit (" << std::to_string(backend_data->tensors.kv_cache_k_0()->dims->data[1]) << ")." << std::endl; + backend_data->tensors.get_tensors(backend_data->prefill_runner, + backend_data->decode_runner); + + if (effective_prefill_token_size + 1 > + backend_data->tensors.kv_cache_k_0()->dims->data[1]) { + LOG(ERROR) << "Input size (" + << std::to_string(effective_prefill_token_size + 1) + << ") exceeds KV cache limit (" + << std::to_string( + backend_data->tensors.kv_cache_k_0()->dims->data[1]) + << ")." << std::endl; return MLPERF_FAILURE; } - - return MLPERF_SUCCESS; } // Return the number of outputs for the model. -int32_t LLMPipeline::backend_get_output_count(mlperf_backend_ptr_t backend_ptr) { - return 2; // 0 is the output string, 1 is the output tokens +int32_t LLMPipeline::backend_get_output_count( + mlperf_backend_ptr_t backend_ptr) { + return 2; // 0 is the output string, 1 is the output tokens } // Return the type of ith output. -mlperf_data_t LLMPipeline::backend_get_output_type(mlperf_backend_ptr_t backend_ptr, int32_t i) { +mlperf_data_t LLMPipeline::backend_get_output_type( + mlperf_backend_ptr_t backend_ptr, int32_t i) { return mlperf_data_t{mlperf_data_t::Float32, 0}; } // Get the data from ith output. -mlperf_status_t LLMPipeline::backend_get_output(mlperf_backend_ptr_t backend_ptr, uint32_t batch_index, int32_t i, void **data) { +mlperf_status_t LLMPipeline::backend_get_output( + mlperf_backend_ptr_t backend_ptr, uint32_t batch_index, int32_t i, + void **data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - if (i == 0){ - MINIMAL_CHECK(backend_data->sp_processor->Decode(backend_data->output_tokens, &backend_data->output).ok()); + if (i == 0) { + MINIMAL_CHECK( + backend_data->sp_processor + ->Decode(backend_data->output_tokens, &backend_data->output) + .ok()); LOG(INFO) << "Output: " << backend_data->output << std::endl; *data = backend_data->output.data(); - } - else if (i == 1) { + } else if (i == 1) { *data = &backend_data->output_tokens; - } - else return MLPERF_FAILURE; + } else + return MLPERF_FAILURE; return MLPERF_SUCCESS; } -void LLMPipeline::backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, uint8_t *data) {} +void LLMPipeline::backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, + int bytes, int width, int height, + uint8_t *data) {} -void LLMPipeline::backend_convert_outputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, uint8_t *data) {} +void LLMPipeline::backend_convert_outputs(mlperf_backend_ptr_t backend_ptr, + int bytes, int width, int height, + uint8_t *data) {} -void *LLMPipeline::backend_get_buffer(size_t n) { - return ::operator new(n); -} +void *LLMPipeline::backend_get_buffer(size_t n) { return ::operator new(n); } -void LLMPipeline::backend_release_buffer(void *p) { - ::operator delete(p); -} +void LLMPipeline::backend_release_buffer(void *p) { ::operator delete(p); } -tflite::Interpreter *LLMPipeline::BuildInterpreter(tflite::FlatBufferModel *model, int num_threads) { +tflite::Interpreter *LLMPipeline::BuildInterpreter( + tflite::FlatBufferModel *model, int num_threads) { tflite::ops::builtin::BuiltinOpResolver resolver; // NOTE: We need to manually register optimized OPs for KV-cache and // Scaled Dot Product Attention (SDPA). @@ -285,8 +331,8 @@ tflite::Interpreter *LLMPipeline::BuildInterpreter(tflite::FlatBufferModel *mode return interpreter.release(); } -kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { - tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); +kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter *interpreter) { + tflite::SignatureRunner *runner = interpreter->GetSignatureRunner("decode"); if (runner == nullptr) { return {}; } @@ -302,7 +348,7 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { std::string k_cache_name = "kv_cache_k_" + std::to_string(i); std::string v_cache_name = "kv_cache_v_" + std::to_string(i); // We are assuming K and V tensors are of the same shape. - TfLiteTensor* tensor = runner->input_tensor(k_cache_name.c_str()); + TfLiteTensor *tensor = runner->input_tensor(k_cache_name.c_str()); size_t count = tensor->bytes / sizeof(float); kv_cache.emplace(k_cache_name, std::vector>(count, 0.0f)); @@ -313,56 +359,68 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { return kv_cache; } -void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, kv_cache_t& kv_cache) { - for (auto& [name, cache] : kv_cache) { - TfLiteCustomAllocation allocation = {.data = static_cast(cache.data()), .bytes = cache.size() * sizeof(float)}; +void LLMPipeline::PrepareRunner(tflite::SignatureRunner *runner, + kv_cache_t &kv_cache) { + for (auto &[name, cache] : kv_cache) { + TfLiteCustomAllocation allocation = { + .data = static_cast(cache.data()), + .bytes = cache.size() * sizeof(float)}; // Both input and output tensors are set to the same buffer. Not all // delegates support this in-place update. For those cases, we need to do // a ping-pong buffer and update the pointers between inference calls. - MINIMAL_CHECK_VOID(runner->SetCustomAllocationForInputTensor(name.c_str(), allocation) == kTfLiteOk); - MINIMAL_CHECK_VOID(runner->SetCustomAllocationForOutputTensor(name.c_str(), allocation) == kTfLiteOk); + MINIMAL_CHECK_VOID(runner->SetCustomAllocationForInputTensor( + name.c_str(), allocation) == kTfLiteOk); + MINIMAL_CHECK_VOID(runner->SetCustomAllocationForOutputTensor( + name.c_str(), allocation) == kTfLiteOk); } MINIMAL_CHECK_VOID(runner->AllocateTensors() == kTfLiteOk); } -tflite::SignatureRunner *LLMPipeline::GetPrefillRunner(tflite::Interpreter* interpreter, std::size_t num_input_tokens, kv_cache_t& kv_cache) { +tflite::SignatureRunner *LLMPipeline::GetPrefillRunner( + tflite::Interpreter *interpreter, std::size_t num_input_tokens, + kv_cache_t &kv_cache) { // Find the prefill signature length that best matches the input token size. - tflite::SignatureRunner* runner = nullptr; - //int best_seq_size = -1; + tflite::SignatureRunner *runner = nullptr; + // int best_seq_size = -1; size_t delta = std::numeric_limits::max(); size_t max_prefill_size = 0; std::string max_prefill_key = std::string(""); - for (const std::string* key : interpreter->signature_keys()) { + for (const std::string *key : interpreter->signature_keys()) { if (key->find("prefill") == std::string::npos) continue; - TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())->input_tensor("input_pos"); + TfLiteTensor *input_pos = interpreter->GetSignatureRunner(key->c_str()) + ->input_tensor("input_pos"); // The expected shape for input position is [Seq]. size_t seq_size = input_pos->dims->data[0]; - //TODO this could be else maybe? + // TODO this could be else maybe? if (seq_size > max_prefill_size) { max_prefill_size = seq_size; max_prefill_key = std::string(key->c_str()); } if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) { runner = interpreter->GetSignatureRunner(key->c_str()); - //best_seq_size = seq_size; + // best_seq_size = seq_size; delta = seq_size - num_input_tokens; } } - //fallback to maximum possible size if a runner is not found (most likely because the seq_size is larger than max_prefill_size) - if (!runner && max_prefill_key != "") runner = interpreter->GetSignatureRunner(max_prefill_key.c_str()); + // fallback to maximum possible size if a runner is not found (most likely + // because the seq_size is larger than max_prefill_size) + if (!runner && max_prefill_key != "") + runner = interpreter->GetSignatureRunner(max_prefill_key.c_str()); MINIMAL_CHECK_PTR(runner != nullptr); PrepareRunner(runner, kv_cache); return runner; } -tflite::SignatureRunner *LLMPipeline::GetDecodeRunner(tflite::Interpreter* interpreter, kv_cache_t& kv_cache) { - tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); +tflite::SignatureRunner *LLMPipeline::GetDecodeRunner( + tflite::Interpreter *interpreter, kv_cache_t &kv_cache) { + tflite::SignatureRunner *runner = interpreter->GetSignatureRunner("decode"); MINIMAL_CHECK_PTR(runner != nullptr); PrepareRunner(runner, kv_cache); return runner; } -sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor(std::string path) { +sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor( + std::string path) { std::ifstream input(path, std::ios::binary); std::string serialized_proto = std::string( std::istreambuf_iterator(input), std::istreambuf_iterator()); @@ -372,7 +430,7 @@ sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor(s } // A basic greedy sampler (equivalent to argmax). -int LLMPipeline::GreedySampler(const TfLiteTensor* logits) { +int LLMPipeline::GreedySampler(const TfLiteTensor *logits) { float max_value = -std::numeric_limits::infinity(); int max_index = 0; // logits shape: [Batch, Seq, Vocab], Dtype: float diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 1c546a3d4..d991c5d1a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -22,31 +22,29 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "pipeline.h" - #include "src/sentencepiece_processor.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/model_builder.h" #include "tensorflow/lite/signature_runner.h" -#include "tensorflow/core/platform/logging.h" - -#define MINIMAL_CHECK(x) \ -if (!(x)) { \ - LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ - return MLPERF_FAILURE; \ -} -#define MINIMAL_CHECK_PTR(x) \ -if (!(x)) { \ - LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ - return nullptr; \ -} -#define MINIMAL_CHECK_VOID(x) \ -if (!(x)) { \ - LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ - return; \ -} +#define MINIMAL_CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return MLPERF_FAILURE; \ + } +#define MINIMAL_CHECK_PTR(x) \ + if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return nullptr; \ + } +#define MINIMAL_CHECK_VOID(x) \ + if (!(x)) { \ + LOG(ERROR) << "Error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return; \ + } // TF Lite requires all buffers (including external buffers used for KV cache // here) be `tflite::kDefaultTensorAlignment` aligned. To ensure that, we use @@ -57,30 +55,31 @@ class AlignedAllocator { public: using value_type = T; - T* allocate(std::size_t n) { - void* ptr; + T *allocate(std::size_t n) { + void *ptr; std::size_t size = n * sizeof(T); // NOTE this part of the code from seems to be redundant - //std::size_t padding = tflite::kDefaultTensorAlignment - + // std::size_t padding = tflite::kDefaultTensorAlignment - // (size % tflite::kDefaultTensorAlignment); - //size += padding; + // size += padding; int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size); if (ret != 0) { return nullptr; } - return static_cast(ptr); + return static_cast(ptr); }; - void deallocate(T* ptr, std::size_t n) { free(ptr); } + void deallocate(T *ptr, std::size_t n) { free(ptr); } }; -using kv_cache_t = std::map>>; +using kv_cache_t = + std::map>>; // A simple container for pointers to the tensors used during inference. // The pointers here should not be managed or deleted by this struct. struct LLMTensors { - - bool get_tensors (tflite::SignatureRunner *prefill_runner, tflite::SignatureRunner *decode_runner) { + bool get_tensors(tflite::SignatureRunner *prefill_runner, + tflite::SignatureRunner *decode_runner) { prefill_input_ = prefill_runner->input_tensor("tokens"); prefill_input_pos_ = prefill_runner->input_tensor("input_pos"); decode_input_ = decode_runner->input_tensor("tokens"); @@ -89,34 +88,35 @@ struct LLMTensors { kv_cache_k_0_ = decode_runner->input_tensor("kv_cache_k_0"); // Making sure none of the tensors are nullptr. - return prefill_input_ && prefill_input_pos_ && decode_input_ && decode_input_pos_ && logits_output_ && kv_cache_k_0_; + return prefill_input_ && prefill_input_pos_ && decode_input_ && + decode_input_pos_ && logits_output_ && kv_cache_k_0_; } - LLMTensors(){} + LLMTensors() {} - LLMTensors(const LLMTensors&) = delete; - LLMTensors& operator=(const LLMTensors&) = delete; + LLMTensors(const LLMTensors &) = delete; + LLMTensors &operator=(const LLMTensors &) = delete; - TfLiteTensor* prefill_input() const {return prefill_input_;} - TfLiteTensor* prefill_input_pos() const {return prefill_input_pos_;} - TfLiteTensor* decode_input() const {return decode_input_;} - TfLiteTensor* decode_input_pos() const {return decode_input_pos_;} - const TfLiteTensor* logits_output() const {return logits_output_;} - TfLiteTensor* kv_cache_k_0() const {return kv_cache_k_0_;} + TfLiteTensor *prefill_input() const { return prefill_input_; } + TfLiteTensor *prefill_input_pos() const { return prefill_input_pos_; } + TfLiteTensor *decode_input() const { return decode_input_; } + TfLiteTensor *decode_input_pos() const { return decode_input_pos_; } + const TfLiteTensor *logits_output() const { return logits_output_; } + TfLiteTensor *kv_cache_k_0() const { return kv_cache_k_0_; } -private: + private: // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* prefill_input_; + TfLiteTensor *prefill_input_; // Shape: [Seq], Dtype: int32 - TfLiteTensor* prefill_input_pos_; + TfLiteTensor *prefill_input_pos_; // Shape: [Batch, Seq], Dtype: int32 - TfLiteTensor* decode_input_; + TfLiteTensor *decode_input_; // Shape: [Seq], Dtype: int32 - TfLiteTensor* decode_input_pos_; + TfLiteTensor *decode_input_pos_; // Shape: [Seq], Dtype: float32 - const TfLiteTensor* logits_output_; + const TfLiteTensor *logits_output_; // shape: [Batch, kv_cache_max, num_query_groups, head_dim] - TfLiteTensor* kv_cache_k_0_; + TfLiteTensor *kv_cache_k_0_; }; struct LLMBackendData { @@ -125,7 +125,8 @@ struct LLMBackendData { const char *accelerator = "CPU"; tflite::FlatBufferModel *model{nullptr}; sentencepiece::SentencePieceProcessor *sp_processor{nullptr}; - //TfLiteInterpreterOptions *options{}; TODO use this to allow different delegates other than CPU? + // TfLiteInterpreterOptions *options{}; TODO use this to allow different + // delegates other than CPU? tflite::Interpreter *interpreter{}; tflite::SignatureRunner *prefill_runner{nullptr}; tflite::SignatureRunner *decode_runner{nullptr}; @@ -140,7 +141,7 @@ struct LLMBackendData { std::string end_token = ""; int stop_token_id = -1; - LLMBackendData(){} + LLMBackendData() {} ~LLMBackendData() { // Runners are owned by interpreter and therefore don't need to be deleted @@ -149,9 +150,8 @@ struct LLMBackendData { delete model; } - LLMBackendData(const LLMBackendData&) = delete; - LLMBackendData& operator=(const LLMBackendData&) = delete; - + LLMBackendData(const LLMBackendData &) = delete; + LLMBackendData &operator=(const LLMBackendData &) = delete; }; // A simple pipeline which runs a single model. @@ -174,7 +174,6 @@ class LLMPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_first_token_query( mlperf_backend_ptr_t backend_ptr) override; @@ -213,16 +212,18 @@ class LLMPipeline : public Pipeline { void backend_release_buffer(void *p) override; private: - tflite::Interpreter *BuildInterpreter(tflite::FlatBufferModel *model, int num_threads); + tflite::Interpreter *BuildInterpreter(tflite::FlatBufferModel *model, + int num_threads); kv_cache_t BuildKVCache(tflite::Interpreter *interpreter); void PrepareRunner(tflite::SignatureRunner *runner, kv_cache_t &kv_cache); - tflite::SignatureRunner *GetPrefillRunner(tflite::Interpreter *interpreter, std::size_t num_input_tokens, kv_cache_t &kv_cache); - tflite::SignatureRunner *GetDecodeRunner(tflite::Interpreter *interpreter, kv_cache_t &kv_cache); - sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor(std::string path); + tflite::SignatureRunner *GetPrefillRunner(tflite::Interpreter *interpreter, + std::size_t num_input_tokens, + kv_cache_t &kv_cache); + tflite::SignatureRunner *GetDecodeRunner(tflite::Interpreter *interpreter, + kv_cache_t &kv_cache); + sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( + std::string path); int GreedySampler(const TfLiteTensor *logits); - - - }; #endif // TFLITE_SINGLE_MODEL_PIPELINE_H_ From 816f282e178ef0c4e744ae46297a05bb2a97956c Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 8 Sep 2025 08:12:48 +0300 Subject: [PATCH 14/74] mmlu dataset cleanup and formatting --- .bazelrc | 3 +- flutter/cpp/datasets/mmlu_gen.cc | 52 ++++++++++++++------------------ flutter/cpp/datasets/mmlu_gen.h | 13 +++----- 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/.bazelrc b/.bazelrc index 2b807cc72..f54b7b828 100644 --- a/.bazelrc +++ b/.bazelrc @@ -54,7 +54,7 @@ build:linux_x86_64 --cpu=k8 # Not required, but enables the proper SSE/MMX instructions per CPU build:linux_x86_64 --copt=-march=native -# These are neccessary because the compiler that bazel 6.3 uses doesn't support VNNI +# These may be neccessary depending on CPU instruction support #build:linux_x86_64 --define=xnn_enable_avx=false #build:linux_x86_64 --define=xnn_enable_avx2=false #build:linux_x86_64 --define=xnn_enable_avx512=false @@ -63,6 +63,7 @@ build:linux_x86_64 --define=xnn_enable_avx512fp16=false build:linux_x86_64 --define=xnn_enable_avxvnniint8=false #build:linux_x86_64 --define=xnn_enable_vnni=false + # Optional, enable for debugging or compilation errors #build:linux_x86_64 --action_env=CC=gcc #build:linux_x86_64 --action_env=CXX=g++ diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 34d586572..3cd8c22b6 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -1,85 +1,83 @@ #include "flutter/cpp/datasets/mmlu_gen.h" -#include "tensorflow/core/example/example.pb.h" -#include "tensorflow/core/example/feature_util.h" #include #include +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature_util.h" + namespace mlperf { namespace mobile { -MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord/*, const std::string& input_sppp*/) +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) : sample_reader_(input_tfrecord), Dataset(backend) { - std::cout << "MMLUT-DATASET: " << "Initializing with TFRecord " << input_tfrecord << " with sample size " << std::to_string(sample_reader_.Size()) << std::endl; // Load all TFRecord samples into memory - //TODO move to MmluGen::LoadSamplesToRam? + // NOTE this can be moved to LoadSamplesToRam, but will cause delays between + // queries due to IO reads happening between queries for (size_t i = 0; i < sample_reader_.Size(); i++) { tensorflow::tstring record = sample_reader_.ReadRecord(i); tensorflow::Example example; example.ParseFromString(record); - std::string input = tensorflow::GetFeatureValues("input", example).Get(0); - std::string answer = tensorflow::GetFeatureValues("answer", example).Get(0); + std::string input = + tensorflow::GetFeatureValues("input", example).Get(0); + std::string answer = + tensorflow::GetFeatureValues("answer", example).Get(0); auto sample = std::make_unique(); sample->input = input; sample->answer = answer; - std::cout << "MMLUT-DATASET: " << "Loading TFRecord Data index " << std::to_string(i) << " with answer {" << answer << "}" << std::endl; - samples_.push_back(std::move(sample)); sample_output_token_counts_.push_back(0); } - //LoadSentencePieceProcessor(input_sppp); } void MmluGen::LoadSamplesToRam(const std::vector& samples) { - std::cout << "MMLUT-DATASET: " << "Loading " << std::to_string(samples.size()) << " samples..." << std::endl; for (auto id : samples) { loaded_sample_ids_.insert(id); } } -void MmluGen::UnloadSamplesFromRam(const std::vector& samples) { +void MmluGen::UnloadSamplesFromRam( + const std::vector& samples) { for (auto id : samples) { loaded_sample_ids_.erase(id); } } std::vector MmluGen::GetData(int sample_idx) { - std::cout << "MMLUT-DATASET: " << "Getting data at index " << std::to_string(sample_idx) << " (Answer is " << samples_[sample_idx]->answer << ")" << std::endl; std::vector data; if (sample_idx < samples_.size()) { - data.push_back(reinterpret_cast(const_cast(samples_[sample_idx]->input.c_str()))); + data.push_back(reinterpret_cast( + const_cast(samples_[sample_idx]->input.c_str()))); } return data; } -std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { +std::vector MmluGen::ProcessOutput(const int sample_idx, + const std::vector& outputs) { if (sample_idx >= samples_.size() || outputs.empty()) return {0}; - sample_output_token_counts_[sample_idx] = reinterpret_cast*>(outputs[1])->size(); + sample_output_token_counts_[sample_idx] = + reinterpret_cast*>(outputs[1])->size(); const char* prediction = reinterpret_cast(outputs[0]); - char predicted_char = prediction[1]; // Assume second token is the answer because of whitespace (e.g., 'A', 'B', ...) - std::cout << "MMLUT-DATASET: " << "Predicted answer: " << predicted_char << std::endl; + char predicted_char = + prediction[1]; // Assume second token is the answer because of whitespace + // (e.g., 'A', 'B', ...) const std::string& correct = samples_[sample_idx]->answer; bool is_correct = (predicted_char == correct[0]); total_++; if (is_correct) correct_++; - std::cout << "MMLUT-DATASET: " << "Accuracy: " << std::to_string(correct_) << "/" << std::to_string(total_) << std::endl; - return {static_cast(is_correct)}; } - int64_t MmluGen::GetOutputTokenCount(const int sample_idx) { return sample_output_token_counts_[sample_idx]; } -bool MmluGen::HasAccuracy() { - return true; -} +bool MmluGen::HasAccuracy() { return true; } float MmluGen::ComputeAccuracy() { return total_ > 0 ? static_cast(correct_) / total_ : 0.0f; @@ -90,11 +88,5 @@ std::string MmluGen::ComputeAccuracyString() { return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; } -//void MmluGen::loadSentencePieceProcessor(std::string path) { -// std::ifstream input(path, std::ios::binary); -// std::string serialized_proto = std::string(std::istreambuf_iterator(input), std::istreambuf_iterator()); -// if(!sp_processor->LoadFromSerializedProto(serialized_proto).ok()) LOG(FATAL) << "Could not load SP Processor"; -//} - } // namespace mobile } // namespace mlperf diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 7a2a96705..24a2ae9dd 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -5,12 +5,11 @@ #include #include +#include #include #include #include -#include -//#include "src/sentencepiece_processor.h" #include "flutter/cpp/dataset.h" #include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" @@ -29,11 +28,13 @@ class MmluGen : public Dataset { void LoadSamplesToRam(const std::vector& samples) override; - void UnloadSamplesFromRam(const std::vector& samples) override; + void UnloadSamplesFromRam( + const std::vector& samples) override; std::vector GetData(int sample_idx) override; - std::vector ProcessOutput(const int sample_idx, const std::vector& outputs) override; + std::vector ProcessOutput( + const int sample_idx, const std::vector& outputs) override; int64_t GetOutputTokenCount(const int sample_idx) override; @@ -43,14 +44,10 @@ class MmluGen : public Dataset { std::string ComputeAccuracyString() override; - private: - //void loadSentencePieceProcessor(std::string path); - const std::string name_ = "MmluGen"; TFRecordReader sample_reader_; - //sentencepiece::SentencePieceProcessor sp_processor; struct PromptSample { std::string input; From fca2905a7aab1f0f675321add12e2f135b593934 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 9 Sep 2025 02:45:26 +0300 Subject: [PATCH 15/74] slight code cleanup --- mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc | 1 - mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 4fbd09f87..805b09f8a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -29,7 +29,6 @@ limitations under the License. #include "flutter/cpp/utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/delegates/gpu/delegate.h" #ifdef __cplusplus extern "C" { diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index d991c5d1a..9cb53a5c5 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -58,7 +58,7 @@ class AlignedAllocator { T *allocate(std::size_t n) { void *ptr; std::size_t size = n * sizeof(T); - // NOTE this part of the code from seems to be redundant + // NOTE this part of the code seems to be redundant // std::size_t padding = tflite::kDefaultTensorAlignment - // (size % tflite::kDefaultTensorAlignment); // size += padding; From 20e7805b9f1c94e499561663a0fa47a2132b0826 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 9 Sep 2025 05:59:55 +0300 Subject: [PATCH 16/74] fixed issue with genai ops import --- mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 9cb53a5c5..047b91324 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -24,6 +24,7 @@ limitations under the License. #include "pipeline.h" #include "src/sentencepiece_processor.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/lite/experimental/genai/genai_ops.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/kernels/register.h" From 83aea46e065bd1cfe6320f73fdd01b5f80a05cce Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sun, 28 Sep 2025 02:42:03 +0300 Subject: [PATCH 17/74] code/config cleanup --- .bazelrc | 4 + WORKSPACE | 18 ++-- flutter/assets/icons/ic_task_llm_white.svg | 101 ++++++++++++++++++ flutter/assets/tasks.pbtxt | 18 ++-- flutter/cpp/backend.h | 3 +- flutter/cpp/backends/external.h | 3 +- flutter/cpp/binary/main.cc | 6 +- flutter/cpp/dataset.h | 3 +- flutter/cpp/flutter/dart_run_benchmark.cc | 2 +- flutter/cpp/mlperf_driver.cc | 20 ++-- flutter/cpp/proto/mlperf_task.proto | 2 +- flutter/lib/ui/icons.dart | 6 +- mobile_back_tflite/cpp/backend_tflite/BUILD | 14 +-- .../tflite_settings_android.pbtxt | 8 +- .../cpp/backend_tflite/embedding_utils.h | 2 +- .../backend_tflite/single_model_pipeline.h | 5 +- .../stable_diffusion_pipeline.h | 5 +- .../cpp/backend_tflite/tflite_c.cc | 14 +-- 18 files changed, 176 insertions(+), 58 deletions(-) create mode 100644 flutter/assets/icons/ic_task_llm_white.svg diff --git a/.bazelrc b/.bazelrc index f54b7b828..91a29048e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -48,6 +48,10 @@ build:android_x86_64 --config=android build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 + +build:android_x86_64 --define=xnn_enable_avx512fp16=false +build:android_x86_64 --define=xnn_enable_avxvnniint8=false + # Linux configs build:linux_x86_64 --config=linux build:linux_x86_64 --cpu=k8 diff --git a/WORKSPACE b/WORKSPACE index 7ab14ceef..59d3edd9f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -48,33 +48,35 @@ http_archive( ) load("@org_tensorflow//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") load("@org_tensorflow//third_party/gpus:rocm_configure.bzl", "rocm_configure") + rocm_configure(name = "local_config_rocm") http_archive( name = "com_google_sentencepiece", - strip_prefix = "sentencepiece-0.1.96", + build_file = "@//patches:sentencepiece.BUILD", + patch_args = ["-p1"], + patches = ["@//patches:com_google_sentencepiece.diff"], sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754", + strip_prefix = "sentencepiece-0.1.96", urls = [ - "https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip" + "https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip", ], - build_file = "@//patches:sentencepiece.BUILD", - patches = ["@//patches:com_google_sentencepiece.diff"], - patch_args = ["-p1"], ) http_archive( name = "darts_clone", + build_file = "@//patches:darts_clone.BUILD", + patch_args = ["-p0"], + patches = ["//patches:darts_no_exceptions.diff"], sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983", urls = [ "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", ], - build_file = "@//patches:darts_clone.BUILD", - patches = ["//patches:darts_no_exceptions.diff"], - patch_args = ["-p0"], ) load( diff --git a/flutter/assets/icons/ic_task_llm_white.svg b/flutter/assets/icons/ic_task_llm_white.svg new file mode 100644 index 000000000..08d0d1c8d --- /dev/null +++ b/flutter/assets/icons/ic_task_llm_white.svg @@ -0,0 +1,101 @@ + + + + + + + + + + + diff --git a/flutter/assets/tasks.pbtxt b/flutter/assets/tasks.pbtxt index 54c59a859..51acd52df 100644 --- a/flutter/assets/tasks.pbtxt +++ b/flutter/assets/tasks.pbtxt @@ -344,17 +344,17 @@ task { scenario: "SingleStream" runs { normal { - min_query_count: 1024 + min_query_count: 100 min_duration: 60 max_duration: 300 } quick { - min_query_count: 128 + min_query_count: 10 min_duration: 10 max_duration: 40 } rapid { - min_query_count: 64 + min_query_count: 6 min_duration: 6 max_duration: 60 } @@ -363,22 +363,22 @@ task { type: MMLU full { name: "TinyMMLU prompt set for LLM" - input_path: "https://thee.dev/mlc/data.tfrecord" - input_checksum: "b564d2c228a867148fa7d6df415a0368" + input_path: "local:///mlperf_datasets/tinymmlu/data.tfrecord" + input_checksum: "c20f9115582217af15e4d9955b41ace1" groundtruth_path: "" groundtruth_checksum: "" } lite { name: "TinyMMLU prompt set for LLM" - input_path: "https://thee.dev/mlc/data.tfrecord" - input_checksum: "b564d2c228a867148fa7d6df415a0368" + input_path: "local:///mlperf_datasets/tinymmlu/data.tfrecord" + input_checksum: "c20f9115582217af15e4d9955b41ace1" groundtruth_path: "" groundtruth_checksum: "" } tiny { name: "TinyMMLU prompt set for LLM" - input_path: "https://thee.dev/mlc/data.tfrecord" - input_checksum: "b564d2c228a867148fa7d6df415a0368" + input_path: "local:///mlperf_datasets/tinymmlu/data.tfrecord" + input_checksum: "c20f9115582217af15e4d9955b41ace1" groundtruth_path: "" groundtruth_checksum: "" } diff --git a/flutter/cpp/backend.h b/flutter/cpp/backend.h index e91454b55..50d970481 100644 --- a/flutter/cpp/backend.h +++ b/flutter/cpp/backend.h @@ -44,7 +44,8 @@ class Backend { // Accelerator name. virtual const std::string& AcceleratorName() const = 0; - // Run inference for token based input (such as LLM prompt). Only needed for LLMs currently. + // Run inference for token based input (such as LLM prompt). Only needed for + // LLMs currently. virtual void IssueFirstTokenQuery() = 0; // Run inference for a sample. Inputs is already set by SetInputs. diff --git a/flutter/cpp/backends/external.h b/flutter/cpp/backends/external.h index 1d9d4e02c..77a1f1aca 100644 --- a/flutter/cpp/backends/external.h +++ b/flutter/cpp/backends/external.h @@ -160,7 +160,8 @@ class ExternalBackend : public Backend { } void IssueFirstTokenQuery() override { - if (backend_functions_.issue_first_token_query(backend_ptr_) != MLPERF_SUCCESS) { + if (backend_functions_.issue_first_token_query(backend_ptr_) != + MLPERF_SUCCESS) { LOG(FATAL) << "Error while inferencing model for first token"; } } diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index fc97e3722..dc62634e0 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -437,9 +437,9 @@ int Main(int argc, char *argv[]) { // Running mlperf. MlperfDriver driver(std::move(dataset), std::move(backend), scenario, batch_size); - driver.RunMLPerfTest(mode, min_query_count, min_duration_ms / 1000.0, - max_duration_ms / 1000.0, - single_stream_expected_latency_ns, output_dir, benchmark_id=="llm"); + driver.RunMLPerfTest( + mode, min_query_count, min_duration_ms / 1000.0, max_duration_ms / 1000.0, + single_stream_expected_latency_ns, output_dir, benchmark_id == "llm"); LOG(INFO) << "Accuracy: " << driver.ComputeAccuracyString(); return 0; } diff --git a/flutter/cpp/dataset.h b/flutter/cpp/dataset.h index ebdd3fc3a..9886c0537 100644 --- a/flutter/cpp/dataset.h +++ b/flutter/cpp/dataset.h @@ -61,8 +61,7 @@ class Dataset : public ::mlperf::QuerySampleLibrary { const int sample_idx, const std::vector& outputs) = 0; // Should be called after ProcessOutput. - virtual int64_t GetOutputTokenCount( - const int sample_idx) {return 0;} + virtual int64_t GetOutputTokenCount(const int sample_idx) { return 0; } virtual bool HasAccuracy() { return false; } diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index 9030496c9..637b59e69 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -12,9 +12,9 @@ #include "flutter/cpp/datasets/coco.h" #include "flutter/cpp/datasets/coco_gen.h" #include "flutter/cpp/datasets/imagenet.h" +#include "flutter/cpp/datasets/mmlu_gen.h" #include "flutter/cpp/datasets/snu_sr.h" #include "flutter/cpp/datasets/squad.h" -#include "flutter/cpp/datasets/mmlu_gen.h" #include "flutter/cpp/mlperf_driver.h" #include "flutter/cpp/proto/backend_setting.pb.h" #include "flutter/cpp/proto/mlperf_task.pb.h" diff --git a/flutter/cpp/mlperf_driver.cc b/flutter/cpp/mlperf_driver.cc index 12f94fc53..7a1ea45cc 100644 --- a/flutter/cpp/mlperf_driver.cc +++ b/flutter/cpp/mlperf_driver.cc @@ -72,24 +72,23 @@ void MlperfDriver::IssueQuery( if (use_tokens_) { ft_responses.clear(); backend_->IssueFirstTokenQuery(); - ft_responses.push_back({sample.id, reinterpret_cast(nullptr), 0}); + ft_responses.push_back( + {sample.id, reinterpret_cast(nullptr), 0}); ::mlperf::FirstTokenComplete(ft_responses.data(), ft_responses.size()); } backend_->IssueQuery(); - // Report to mlperf. std::vector outputs = backend_->GetPredictedOutputs(); response_data.push_back(dataset_->ProcessOutput(sample.index, outputs)); - if (use_tokens_){ + if (use_tokens_) { responses.push_back( {sample.id, reinterpret_cast(response_data[idx].data()), response_data[idx].size(), dataset_->GetOutputTokenCount(sample.index)}); - } - else { + } else { responses.push_back( {sample.id, reinterpret_cast(response_data[idx].data()), @@ -105,7 +104,8 @@ void MlperfDriver::IssueQuery( void MlperfDriver::RunMLPerfTest(const std::string& mode, int min_query_count, double min_duration, double max_duration, int single_stream_expected_latency_ns, - const std::string& output_dir, bool use_tokens) { + const std::string& output_dir, + bool use_tokens) { ::mlperf::LogSettings log_settings; log_settings.log_output.outdir = output_dir; log_settings.log_output.copy_summary_to_stdout = true; @@ -116,12 +116,12 @@ void MlperfDriver::RunMLPerfTest(const std::string& mode, int min_query_count, mlperf_settings.sample_index_rng_seed = 10688027786191513374UL; mlperf_settings.schedule_rng_seed = 14962580496156340209UL; - //mlperf_settings.min_query_count = 1; - //mlperf_settings.max_query_count = 2; - //mlperf_settings.performance_sample_count_override = 5; + // mlperf_settings.min_query_count = 1; + // mlperf_settings.max_query_count = 2; + // mlperf_settings.performance_sample_count_override = 5; use_tokens_ = use_tokens; mlperf_settings.use_token_latencies = use_tokens; - //mlperf_settings.server_target_qps = 0.1; + // mlperf_settings.server_target_qps = 0.1; mlperf_settings.mode = Str2TestMode(mode); mlperf_settings.min_duration_ms = static_cast(std::ceil(min_duration * 1000.0)); diff --git a/flutter/cpp/proto/mlperf_task.proto b/flutter/cpp/proto/mlperf_task.proto index 3eeb843a6..e87f47d1c 100644 --- a/flutter/cpp/proto/mlperf_task.proto +++ b/flutter/cpp/proto/mlperf_task.proto @@ -69,7 +69,7 @@ message OneRunConfig { // Datasets for a task // -// Next ID: 5 +// Next ID: 8 message DatasetConfig { // Type of the dataset. enum DatasetType { diff --git a/flutter/lib/ui/icons.dart b/flutter/lib/ui/icons.dart index e47263e61..4dc73d5b5 100644 --- a/flutter/lib/ui/icons.dart +++ b/flutter/lib/ui/icons.dart @@ -28,8 +28,7 @@ class AppIcons { _pSvg('ic_task_super_resolution.svg'); static final SvgPicture stableDiffusion = _pSvg('ic_task_stable_diffusion.svg'); - static final SvgPicture llm = - _pSvg('ic_task_llm.svg'); + static final SvgPicture llm = _pSvg('ic_task_llm.svg'); static final SvgPicture imageClassificationWhite = _pSvg('ic_task_image_classification_white.svg'); @@ -45,6 +44,7 @@ class AppIcons { _pSvg('ic_task_super_resolution_white.svg'); static final SvgPicture stableDiffusionWhite = _pSvg('ic_task_stable_diffusion_white.svg'); + static final SvgPicture llmWhite = _pSvg('ic_task_llm_white.svg'); static final SvgPicture arrow = _pSvg('ic_arrow.svg'); @@ -84,7 +84,7 @@ class BenchmarkIcons { BenchmarkId.stableDiffusion: AppIcons.stableDiffusionWhite, BenchmarkId.imageClassificationOfflineV2: AppIcons.imageClassificationOfflineWhite, - BenchmarkId.llm: AppIcons.llm, + BenchmarkId.llm: AppIcons.llmWhite, }; static Widget getDarkIcon(String benchmarkId) => diff --git a/mobile_back_tflite/cpp/backend_tflite/BUILD b/mobile_back_tflite/cpp/backend_tflite/BUILD index c93e92471..3c1cbefdd 100644 --- a/mobile_back_tflite/cpp/backend_tflite/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/BUILD @@ -17,6 +17,8 @@ load( "tflite_copts", "tflite_jni_binary", ) +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//flutter/cpp/proto:pbtxt2header.bzl", "pbtxt2header") package( @@ -50,21 +52,21 @@ cc_library( name = "tflite_c", srcs = [ "embedding_utils.cc", + "llm_pipeline.cc", "sd_utils.cc", "single_model_pipeline.cc", "stable_diffusion_invoker.cc", "stable_diffusion_pipeline.cc", - "llm_pipeline.cc", "tflite_c.cc", ], hdrs = [ "embedding_utils.h", + "llm_pipeline.h", "pipeline.h", "sd_utils.h", "single_model_pipeline.h", "stable_diffusion_invoker.h", "stable_diffusion_pipeline.h", - "llm_pipeline.h", "tflite_settings_android.h", "tflite_settings_apple.h", "tflite_settings_windows.h", @@ -87,17 +89,17 @@ cc_library( "//flutter/cpp/c:headers", "@com_google_sentencepiece//:sentencepiece_processor", "@org_tensorflow//tensorflow/core:tflite_portable_logging", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:util", "@org_tensorflow//tensorflow/lite/c:c_api", "@org_tensorflow//tensorflow/lite/c:c_api_experimental", "@org_tensorflow//tensorflow/lite/c:common", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite:util", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ] + select({ "@org_tensorflow//tensorflow:android": [ - "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:delegate", + "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", ], "@org_tensorflow//tensorflow:ios": [ "@org_tensorflow//tensorflow/lite/delegates/coreml:coreml_delegate", diff --git a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt index 5df21de7d..a784857a5 100644 --- a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt +++ b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt @@ -266,8 +266,12 @@ benchmark_setting { accelerator_name: "cpu" accelerator_desc: "CPU" model_file: { - model_path: "https://thee.dev/mlc/model.tflite" #Placeholder - model_checksum: "04f62ae20a0f1c68c138f30d88411be0" + model_path: "local:///mlperf_models/llama_q8_ekv3072.tflite" + model_checksum: "54efe0be372b55303673245067beef62" + } + model_file: { + model_path: "local:///mlperf_models/llama3_1b.spm.model" + model_checksum: "2ad260fc18b965ce16006d76c9327082" } } delegate_selected: "CPU" diff --git a/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h b/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h index 74951ed1b..4d972cb3b 100644 --- a/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h +++ b/mobile_back_tflite/cpp/backend_tflite/embedding_utils.h @@ -1,12 +1,12 @@ #ifndef EMBEDDING_UTILS_H_ #define EMBEDDING_UTILS_H_ +#include #include #include #include #include #include -#include class TsEmbeddingParser { public: diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h index 78f09a66a..13aad1c43 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h @@ -37,9 +37,10 @@ class SingleModelPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) override {return MLPERF_FAILURE;} + mlperf_backend_ptr_t backend_ptr) override { + return MLPERF_FAILURE; + } mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h index 0b4046152..42f6b5725 100644 --- a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h @@ -64,9 +64,10 @@ class StableDiffusionPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) override {return MLPERF_FAILURE;} + mlperf_backend_ptr_t backend_ptr) override { + return MLPERF_FAILURE; + } mlperf_status_t backend_issue_query( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc index 4639b54cc..cf28b8bca 100644 --- a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc +++ b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc @@ -10,9 +10,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include + +#include "llm_pipeline.h" #include "single_model_pipeline.h" #include "stable_diffusion_pipeline.h" -#include "llm_pipeline.h" #include "tensorflow/core/platform/logging.h" #include "tflite_settings_android.h" #include "tflite_settings_apple.h" @@ -39,7 +40,7 @@ extern "C" { std::unique_ptr pipeline; void init_pipeline(const char *pipeline_type) { - //TODO use a switch/case + // TODO use a switch/case bool sd_pipeline = (strcmp(pipeline_type, "StableDiffusionPipeline") == 0); bool llm_pipeline = (strcmp(pipeline_type, "LLMPipeline") == 0); if (sd_pipeline) { @@ -48,8 +49,7 @@ void init_pipeline(const char *pipeline_type) { } else if (llm_pipeline) { LOG(INFO) << "Initializing LLMPipeline"; pipeline = std::make_unique(); - } - else { + } else { LOG(INFO) << "Initializing SingleModelPipeline"; pipeline = std::make_unique(); } @@ -153,7 +153,8 @@ bool mlperf_backend_matches_hardware(const char **not_allowed_message, mlperf_backend_ptr_t mlperf_backend_create( const char *model_path, mlperf_backend_configuration_t *configs, const char *native_lib_path) { - LOG(INFO) << "Using TfLite " << TfLiteVersion() << " With Schema " << TfLiteSchemaVersion() << std::endl; + LOG(INFO) << "Using TfLite " << TfLiteVersion() << " With Schema " + << TfLiteSchemaVersion() << std::endl; const char *pipeline_type = ""; for (int i = 0; i < configs->count; ++i) { if (strcmp(configs->keys[i], "pipeline") == 0) { @@ -186,7 +187,8 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { reset_pipeline(); } -mlperf_status_t mlperf_backend_issue_first_token_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_first_token_query( + mlperf_backend_ptr_t backend_ptr) { return pipeline->backend_issue_first_token_query(backend_ptr); } From 61a5c8a9c8c346af02374915fde99b8617d238ad Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sun, 28 Sep 2025 04:29:13 +0300 Subject: [PATCH 18/74] add zero-shot option to MMLU constructor --- flutter/cpp/binary/main.cc | 6 +++++- flutter/cpp/datasets/mmlu_gen.cc | 6 ++++-- flutter/cpp/datasets/mmlu_gen.h | 2 +- flutter/cpp/flutter/dart_run_benchmark.cc | 2 +- mobile_back_tflite/cpp/backend_tflite/BUILD | 2 -- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index dc62634e0..b5e442ad8 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -395,6 +395,7 @@ int Main(int argc, char *argv[]) { dataset_flags.end()); } break; case DatasetConfig::MMLU: { + bool zero_shot = false; LOG(INFO) << "TinyMMLU dataset for LLM benchmark"; std::string input_tfrecord, input_clip_model = ""; std::vector dataset_flags{ @@ -402,11 +403,14 @@ int Main(int argc, char *argv[]) { "input_tfrecord", &input_tfrecord, "Path to the tfrecord file containing inputs for the model.", Flag::kRequired), + Flag::CreateFlag( + "zero-shot", &zero_shot, + "Use zero-shot prompts instead of the default few-shot."), }; if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && backend) { - dataset.reset(new MmluGen(backend.get(), input_tfrecord)); + dataset.reset(new MmluGen(backend.get(), input_tfrecord, zero_shot)); } // Adds to flag_list for showing help. flag_list.insert(flag_list.end(), dataset_flags.begin(), diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 3cd8c22b6..d8b0d796e 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -9,11 +9,11 @@ namespace mlperf { namespace mobile { -MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, bool zero_shot) : sample_reader_(input_tfrecord), Dataset(backend) { // Load all TFRecord samples into memory // NOTE this can be moved to LoadSamplesToRam, but will cause delays between - // queries due to IO reads happening between queries + // queries due to IO reads happening between them for (size_t i = 0; i < sample_reader_.Size(); i++) { tensorflow::tstring record = sample_reader_.ReadRecord(i); tensorflow::Example example; @@ -23,6 +23,8 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord) std::string answer = tensorflow::GetFeatureValues("answer", example).Get(0); + if (zero_shot) input = input.substr(input.rfind("\n\n")+2); // input-formatted shots are separated by 2 new lines + auto sample = std::make_unique(); sample->input = input; sample->answer = answer; diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 24a2ae9dd..7e7cc96c5 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -18,7 +18,7 @@ namespace mobile { class MmluGen : public Dataset { public: - MmluGen(Backend* backend, const std::string& input_tfrecord); + MmluGen(Backend* backend, const std::string& input_tfrecord, bool zero_shot); const std::string& Name() override { return name_; } diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index 637b59e69..fd9c9408f 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -108,7 +108,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( break; case ::mlperf::mobile::DatasetConfig::MMLU: dataset = std::make_unique<::mlperf::mobile::MmluGen>( - backend.get(), in->dataset_data_path); + backend.get(), in->dataset_data_path, true /*zero-shot*/); break; default: return nullptr; diff --git a/mobile_back_tflite/cpp/backend_tflite/BUILD b/mobile_back_tflite/cpp/backend_tflite/BUILD index 3c1cbefdd..1bd13343c 100644 --- a/mobile_back_tflite/cpp/backend_tflite/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/BUILD @@ -17,8 +17,6 @@ load( "tflite_copts", "tflite_jni_binary", ) -load("@rules_cc//cc:cc_binary.bzl", "cc_binary") -load("@rules_cc//cc:cc_library.bzl", "cc_library") load("//flutter/cpp/proto:pbtxt2header.bzl", "pbtxt2header") package( From 54adcd0aeb907e05b314d00e8188a0a2c8088b21 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 29 Sep 2025 03:57:34 +0300 Subject: [PATCH 19/74] use function to detect which token is answer letter --- flutter/cpp/datasets/mmlu_gen.cc | 40 +++++++++++++++++++++++++++++--- flutter/cpp/datasets/mmlu_gen.h | 2 ++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index d8b0d796e..6f0677ba8 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -63,10 +63,12 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, sample_output_token_counts_[sample_idx] = reinterpret_cast*>(outputs[1])->size(); const char* prediction = reinterpret_cast(outputs[0]); - char predicted_char = - prediction[1]; // Assume second token is the answer because of whitespace - // (e.g., 'A', 'B', ...) + + char predicted_char = find_answer_char(prediction); const std::string& correct = samples_[sample_idx]->answer; + + LOG(INFO) << "expected " << correct << " got " << predicted_char << std::endl; + bool is_correct = (predicted_char == correct[0]); total_++; @@ -90,5 +92,37 @@ std::string MmluGen::ComputeAccuracyString() { return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; } +char MmluGen::find_answer_char(const char* input) { + if (!input) return 0; + + const unsigned char* c = reinterpret_cast(input); + + while (*c) { + // skip leading whitespace + while (*c && std::isspace(*c)) ++c; + if (!*c) break; + + const unsigned char* start = c; // start of word + + // quick check: is the word exactly 1 char long? + ++c; // move to potential second char + if (!*c || std::isspace(*c) || *c == '<') { + if (*start == 'A' || + *start == 'B' || + *start == 'C' || + *start == 'D' || + *start == 'a' || + *start == 'b' || + *start == 'c' || + *start == 'd') + return *start; + } else { + // skip rest of this (longer) word quickly + while (*c && !std::isspace(*c)) ++c; + } + } + return 0; +} + } // namespace mobile } // namespace mlperf diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 7e7cc96c5..f149c7e58 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -47,6 +47,8 @@ class MmluGen : public Dataset { private: const std::string name_ = "MmluGen"; + char find_answer_char(const char* input); + TFRecordReader sample_reader_; struct PromptSample { From 65f797f0b8dad75c2d49b3de2a21c844f7ff6ad3 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 29 Sep 2025 04:40:52 +0300 Subject: [PATCH 20/74] quick initial implementation of first token callback --- flutter/cpp/backend.h | 7 ++---- flutter/cpp/backends/external.cc | 2 -- flutter/cpp/backends/external.h | 15 +++-------- flutter/cpp/c/type.h | 2 ++ flutter/cpp/mlperf_driver.cc | 25 ++++++++++++------- .../cpp/backend_tflite/llm_pipeline.cc | 5 +++- .../cpp/backend_tflite/llm_pipeline.h | 4 +-- .../cpp/backend_tflite/pipeline.h | 5 +--- .../backend_tflite/single_model_pipeline.cc | 2 +- .../backend_tflite/single_model_pipeline.h | 7 +----- .../stable_diffusion_pipeline.cc | 2 +- .../stable_diffusion_pipeline.h | 7 +----- .../cpp/backend_tflite/tflite_c.cc | 11 +++----- 13 files changed, 38 insertions(+), 56 deletions(-) diff --git a/flutter/cpp/backend.h b/flutter/cpp/backend.h index 50d970481..6a73aee0d 100644 --- a/flutter/cpp/backend.h +++ b/flutter/cpp/backend.h @@ -44,12 +44,9 @@ class Backend { // Accelerator name. virtual const std::string& AcceleratorName() const = 0; - // Run inference for token based input (such as LLM prompt). Only needed for - // LLMs currently. - virtual void IssueFirstTokenQuery() = 0; - // Run inference for a sample. Inputs is already set by SetInputs. - virtual void IssueQuery() = 0; + // TODO might be good to provide the callback and context along with the inputs if possible + virtual void IssueQuery(ft_callback callback, void* context) = 0; // Flush the staged queries immediately. virtual void FlushQueries() = 0; diff --git a/flutter/cpp/backends/external.cc b/flutter/cpp/backends/external.cc index 363c6863e..3675e4943 100644 --- a/flutter/cpp/backends/external.cc +++ b/flutter/cpp/backends/external.cc @@ -159,8 +159,6 @@ BackendFunctions::BackendFunctions(const std::string& lib_path) { destroy = reinterpret_cast(GetSymbol("mlperf_backend_delete")); - issue_first_token_query = reinterpret_cast( - GetSymbol("mlperf_backend_issue_first_token_query")); issue_query = reinterpret_cast( GetSymbol("mlperf_backend_issue_query")); flush_queries = reinterpret_cast( diff --git a/flutter/cpp/backends/external.h b/flutter/cpp/backends/external.h index 77a1f1aca..ca9a4296e 100644 --- a/flutter/cpp/backends/external.h +++ b/flutter/cpp/backends/external.h @@ -47,10 +47,8 @@ struct BackendFunctions { using AcceleratorNamePtr = std::add_pointer::type; using BackendDeletePtr = std::add_pointer::type; - using IssueFirstTokenQueryPtr = - std::add_pointer::type; using IssueQueryPtr = - std::add_pointer::type; + std::add_pointer::type; using FlushQueriesPtr = std::add_pointer::type; @@ -80,7 +78,6 @@ struct BackendFunctions { AcceleratorNamePtr accelerator_name{nullptr}; BackendDeletePtr destroy{nullptr}; - IssueFirstTokenQueryPtr issue_first_token_query{nullptr}; IssueQueryPtr issue_query{nullptr}; FlushQueriesPtr flush_queries{nullptr}; @@ -159,15 +156,9 @@ class ExternalBackend : public Backend { return accelerator_name_; } - void IssueFirstTokenQuery() override { - if (backend_functions_.issue_first_token_query(backend_ptr_) != - MLPERF_SUCCESS) { - LOG(FATAL) << "Error while inferencing model for first token"; - } - } // Run inference for a sample. - void IssueQuery() override { - if (backend_functions_.issue_query(backend_ptr_) != MLPERF_SUCCESS) { + void IssueQuery(ft_callback callback, void* context) override { + if (backend_functions_.issue_query(backend_ptr_, callback, context) != MLPERF_SUCCESS) { LOG(FATAL) << "Error while inferencing model"; } } diff --git a/flutter/cpp/c/type.h b/flutter/cpp/c/type.h index 69ba2242c..f47f7ee63 100644 --- a/flutter/cpp/c/type.h +++ b/flutter/cpp/c/type.h @@ -61,6 +61,8 @@ typedef struct { const char* native_lib_path; } mlperf_device_info_t; +typedef void (*ft_callback)(void* context); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/flutter/cpp/mlperf_driver.cc b/flutter/cpp/mlperf_driver.cc index 7a1ea45cc..343f627d1 100644 --- a/flutter/cpp/mlperf_driver.cc +++ b/flutter/cpp/mlperf_driver.cc @@ -28,6 +28,12 @@ limitations under the License. namespace mlperf { namespace mobile { +// A method to be called by the backend as soon as the first token is generated (only for token based benchmarks) +static void FirstTokenCallback(void* context) { + auto ft_responses = *(reinterpret_cast*>(context)); + ::mlperf::FirstTokenComplete(ft_responses.data(), ft_responses.size()); +} + void MlperfDriver::IssueQuery( const std::vector<::mlperf::QuerySample>& samples) { std::vector<::mlperf::QuerySampleResponse> responses; @@ -47,7 +53,12 @@ void MlperfDriver::IssueQuery( backend_->SetInputs(inputs, b); } - backend_->IssueQuery(); + // TODO maybe don't do these 2 lines for non token stuff + // TODO figure out what this vector sample variable is + ft_responses.clear(); + ft_responses.push_back({sample.back().index, reinterpret_cast(nullptr), 0}); + + backend_->IssueQuery(&FirstTokenCallback, reinterpret_cast(&ft_responses)); for (int b = 0; b < batch_; b++) { if (idx + b == samples.size()) break; // ignore extra data @@ -69,15 +80,11 @@ void MlperfDriver::IssueQuery( std::vector inputs = dataset_->GetData(sample.index); backend_->SetInputs(inputs); - if (use_tokens_) { - ft_responses.clear(); - backend_->IssueFirstTokenQuery(); - ft_responses.push_back( - {sample.id, reinterpret_cast(nullptr), 0}); - ::mlperf::FirstTokenComplete(ft_responses.data(), ft_responses.size()); - } + // TODO maybe don't do these 2 lines for non token stuff + ft_responses.clear(); + ft_responses.push_back({sample.id, reinterpret_cast(nullptr), 0}); - backend_->IssueQuery(); + backend_->IssueQuery(&FirstTokenCallback, reinterpret_cast(&ft_responses)); // Report to mlperf. std::vector outputs = backend_->GetPredictedOutputs(); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 805b09f8a..05d231e57 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -167,9 +167,12 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query( // Run the output token producing decode inference. // This function exclusively takes output tokens to produce more output tokens. mlperf_status_t LLMPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + backend_issue_first_token_query(backend_ptr); + callback(context); + int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; size_t input_size = backend_data->prompt_tokens.size(); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 047b91324..24f212a7f 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -176,10 +176,10 @@ class LLMPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) override; + mlperf_backend_ptr_t backend_ptr); mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr) override; + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/pipeline.h b/mobile_back_tflite/cpp/backend_tflite/pipeline.h index 86747ffff..ad82c002a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/pipeline.h @@ -40,12 +40,9 @@ class Pipeline { // Return the name of this backend. virtual const char *backend_name(mlperf_backend_ptr_t backend_ptr) = 0; - virtual mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) = 0; - // Run the inference for a sample. virtual mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr) = 0; + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) = 0; // Flush the staged queries immediately. virtual mlperf_status_t backend_flush_queries( diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc index a75985373..bf60dddbf 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc @@ -425,7 +425,7 @@ const char *SingleModelPipeline::backend_name( // Run the inference for a sample. mlperf_status_t SingleModelPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { TFLiteBackendData *backend_data = (TFLiteBackendData *)backend_ptr; #ifdef MTK_TFLITE_NEURON_BACKEND diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h index 13aad1c43..37b323b87 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h @@ -37,13 +37,8 @@ class SingleModelPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) override { - return MLPERF_FAILURE; - } - mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr) override; + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc index fd61f4518..be6dcb247 100644 --- a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.cc @@ -175,7 +175,7 @@ void StableDiffusionPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { } mlperf_status_t StableDiffusionPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { SDBackendData* backend_data = (SDBackendData*)backend_ptr; StableDiffusionInvoker* invoker = new StableDiffusionInvoker(backend_data); backend_data->output = invoker->invoke(); diff --git a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h index 42f6b5725..b2dbdbcb4 100644 --- a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h @@ -64,13 +64,8 @@ class StableDiffusionPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) override { - return MLPERF_FAILURE; - } - mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr) override; + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc index cf28b8bca..1bb6b96a5 100644 --- a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc +++ b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc @@ -187,14 +187,11 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { reset_pipeline(); } -mlperf_status_t mlperf_backend_issue_first_token_query( - mlperf_backend_ptr_t backend_ptr) { - return pipeline->backend_issue_first_token_query(backend_ptr); -} - // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { - return pipeline->backend_issue_query(backend_ptr); +// callback and context are only used when running token based inferences (LLM). +// In other cases they can be passed as nullptr +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { + return pipeline->backend_issue_query(backend_ptr, callback, context); } // Flush the staged queries immediately. From 719aefad52bb421ded9006e742efb1071bf562d2 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 29 Sep 2025 07:45:09 +0300 Subject: [PATCH 21/74] moved tokenizer to dataset side (possibly needs cleanup) --- flutter/cpp/binary/main.cc | 8 ++- flutter/cpp/datasets/BUILD | 2 + flutter/cpp/datasets/mmlu_gen.cc | 38 +++++++---- flutter/cpp/datasets/mmlu_gen.h | 13 +++- flutter/cpp/flutter/dart_run_benchmark.cc | 7 ++- mobile_back_tflite/cpp/backend_tflite/BUILD | 1 - .../cpp/backend_tflite/llm_pipeline.cc | 63 ++++--------------- .../cpp/backend_tflite/llm_pipeline.h | 8 --- 8 files changed, 63 insertions(+), 77 deletions(-) diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index b5e442ad8..ddc1ecda9 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -397,12 +397,16 @@ int Main(int argc, char *argv[]) { case DatasetConfig::MMLU: { bool zero_shot = false; LOG(INFO) << "TinyMMLU dataset for LLM benchmark"; - std::string input_tfrecord, input_clip_model = ""; + std::string input_tfrecord, sp_path = ""; std::vector dataset_flags{ Flag::CreateFlag( "input_tfrecord", &input_tfrecord, "Path to the tfrecord file containing inputs for the model.", Flag::kRequired), + Flag::CreateFlag( + "sp_path", &sp_path, + "Path to the sentencepiece model file.", + Flag::kRequired), Flag::CreateFlag( "zero-shot", &zero_shot, "Use zero-shot prompts instead of the default few-shot."), @@ -410,7 +414,7 @@ int Main(int argc, char *argv[]) { if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && backend) { - dataset.reset(new MmluGen(backend.get(), input_tfrecord, zero_shot)); + dataset.reset(new MmluGen(backend.get(), input_tfrecord, sp_path, zero_shot)); } // Adds to flag_list for showing help. flag_list.insert(flag_list.end(), dataset_flags.begin(), diff --git a/flutter/cpp/datasets/BUILD b/flutter/cpp/datasets/BUILD index ebcd71587..dd739b0de 100644 --- a/flutter/cpp/datasets/BUILD +++ b/flutter/cpp/datasets/BUILD @@ -229,6 +229,8 @@ cc_library( "//flutter/cpp:utils", "//flutter/cpp/backends:external", "//flutter/cpp/datasets/squad_utils", + "//flutter/cpp/datasets/mmlu_utils", + "@com_google_sentencepiece//:sentencepiece_processor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_protobuf//:protobuf", "@org_tensorflow//tensorflow/lite/tools/evaluation:utils", diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 6f0677ba8..8ab3c98da 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -3,14 +3,21 @@ #include #include +#include "flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature_util.h" namespace mlperf { namespace mobile { -MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, bool zero_shot) +// TODO add eos and bos tokens as config parameters +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, const std::string& sp_path, bool zero_shot) : sample_reader_(input_tfrecord), Dataset(backend) { + + sp_processor = std::unique_ptr(LoadSentencePieceProcessor(sp_path)); + start_token_id = sp_processor->PieceToId(start_token); + end_token_id = sp_processor->PieceToId(end_token); + // Load all TFRecord samples into memory // NOTE this can be moved to LoadSamplesToRam, but will cause delays between // queries due to IO reads happening between them @@ -25,8 +32,15 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, bool zero_ if (zero_shot) input = input.substr(input.rfind("\n\n")+2); // input-formatted shots are separated by 2 new lines + std::vector input_tokens; + + sp_processor->Encode(input.c_str(), &input_tokens).ok(); + + input_tokens.insert(input_tokens.begin(), start_token_id); + auto sample = std::make_unique(); sample->input = input; + sample->input_tokens = input_tokens; sample->answer = answer; samples_.push_back(std::move(sample)); @@ -49,9 +63,10 @@ void MmluGen::UnloadSamplesFromRam( std::vector MmluGen::GetData(int sample_idx) { std::vector data; + if (sample_idx < samples_.size()) { - data.push_back(reinterpret_cast( - const_cast(samples_[sample_idx]->input.c_str()))); + data.push_back(reinterpret_cast(const_cast*>(&(samples_[sample_idx]->input_tokens)))); + data.push_back(reinterpret_cast(const_cast(&end_token_id))); } return data; } @@ -60,15 +75,17 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { if (sample_idx >= samples_.size() || outputs.empty()) return {0}; - sample_output_token_counts_[sample_idx] = - reinterpret_cast*>(outputs[1])->size(); - const char* prediction = reinterpret_cast(outputs[0]); + const auto& output_tokens = *(reinterpret_cast*>(outputs[0])); + LOG(INFO) << "~getout " << output_tokens.size() << std::endl; + + sample_output_token_counts_[sample_idx] = output_tokens.size(); + + std::string prediction; + sp_processor->Decode(output_tokens, &prediction).ok(); char predicted_char = find_answer_char(prediction); const std::string& correct = samples_[sample_idx]->answer; - LOG(INFO) << "expected " << correct << " got " << predicted_char << std::endl; - bool is_correct = (predicted_char == correct[0]); total_++; @@ -92,10 +109,9 @@ std::string MmluGen::ComputeAccuracyString() { return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; } -char MmluGen::find_answer_char(const char* input) { - if (!input) return 0; +char MmluGen::find_answer_char(const std::string& input) { - const unsigned char* c = reinterpret_cast(input); + const unsigned char* c = reinterpret_cast(input.c_str()); while (*c) { // skip leading whitespace diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index f149c7e58..9e05508ae 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -10,6 +10,7 @@ #include #include +#include "src/sentencepiece_processor.h" #include "flutter/cpp/dataset.h" #include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" @@ -18,7 +19,7 @@ namespace mobile { class MmluGen : public Dataset { public: - MmluGen(Backend* backend, const std::string& input_tfrecord, bool zero_shot); + MmluGen(Backend* backend, const std::string& input_tfrecord, const std::string& sp_path, bool zero_shot); const std::string& Name() override { return name_; } @@ -47,21 +48,29 @@ class MmluGen : public Dataset { private: const std::string name_ = "MmluGen"; - char find_answer_char(const char* input); + char find_answer_char(const std::string& input); TFRecordReader sample_reader_; struct PromptSample { std::string input; + std::vector input_tokens; std::string answer; }; std::vector> samples_; std::vector sample_output_token_counts_; std::set loaded_sample_ids_; + std::unique_ptr sp_processor; size_t correct_ = 0; size_t total_ = 0; + + + std::string start_token = ""; + std::string end_token = ""; + int start_token_id; + int end_token_id; }; } // namespace mobile diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index fd9c9408f..49d804d73 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -74,6 +74,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( out->accelerator_name = strdup(backend->AcceleratorName().c_str()); ::std::unique_ptr<::mlperf::mobile::Dataset> dataset; + std::string sp_path; switch (in->dataset_type) { case ::mlperf::mobile::DatasetConfig::IMAGENET: dataset = std::make_unique<::mlperf::mobile::Imagenet>( @@ -107,8 +108,12 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( in->output_dir); break; case ::mlperf::mobile::DatasetConfig::MMLU: + for (auto setting : settings.benchmark_setting().custom_setting()) + { + if (setting.id() == "llm_tokenizer_path") sp_path = setting.value(); + } dataset = std::make_unique<::mlperf::mobile::MmluGen>( - backend.get(), in->dataset_data_path, true /*zero-shot*/); + backend.get(), in->dataset_data_path, sp_path, true /*zero-shot*/); break; default: return nullptr; diff --git a/mobile_back_tflite/cpp/backend_tflite/BUILD b/mobile_back_tflite/cpp/backend_tflite/BUILD index 1bd13343c..9697c5a1f 100644 --- a/mobile_back_tflite/cpp/backend_tflite/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/BUILD @@ -85,7 +85,6 @@ cc_library( ":tflite_settings", "//flutter/cpp:utils", "//flutter/cpp/c:headers", - "@com_google_sentencepiece//:sentencepiece_processor", "@org_tensorflow//tensorflow/core:tflite_portable_logging", "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:util", diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 05d231e57..24edb9e11 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -44,7 +44,6 @@ void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Create a new backend and return the pointer to it. -// TODO add eos and bos tokens as config parameters mlperf_backend_ptr_t LLMPipeline::backend_create( const char *model_path, mlperf_backend_configuration_t *configs, const char *native_lib_path) { @@ -56,10 +55,6 @@ mlperf_backend_ptr_t LLMPipeline::backend_create( LLMBackendData *backend_data = new LLMBackendData(); - // sentencePiece Processor Path - std::string sppp = mlperf::mobile::GetConfigValue( - configs, "sentencepiece_processor_path", std::string("")); - // Load the model. backend_data->model = tflite::FlatBufferModel::BuildFromFile(model_path).release(); @@ -83,18 +78,6 @@ mlperf_backend_ptr_t LLMPipeline::backend_create( backend_data->decode_runner = GetDecodeRunner(backend_data->interpreter, backend_data->kv_cache); - backend_data->sp_processor = LoadSentencePieceProcessor(sppp); - if (!backend_data->sp_processor) { - LOG(ERROR) << "Failed to load sentencepiece processor: " << sppp; - backend_delete(backend_data); - return nullptr; - } - - if (!backend_data->end_token.empty()) { - backend_data->stop_token_id = - backend_data->sp_processor->PieceToId((backend_data->end_token)); - } - return backend_data; } @@ -209,10 +192,10 @@ mlperf_status_t LLMPipeline::backend_flush_queries( } // Return the number of inputs of the model. -// Only 1 input needs to be provided, which is the tokens themselves, the other +// Only 2 inputs need to be provided, the tokens themselves, and the EOS token. The other // inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { - return 1; + return 2; } // Return the type of the ith input. @@ -227,25 +210,19 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, void *data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - std::string prompt = std::string(static_cast(data)); + if (i == 1) { + backend_data->stop_token_id = *(reinterpret_cast(data)); + return MLPERF_SUCCESS; + } // Reset the tokens and kv caches from potential previous runs. - backend_data->prompt_tokens.clear(); backend_data->output_tokens.clear(); for (auto &[_, vec] : backend_data->kv_cache) { std::fill(vec.begin(), vec.end(), 0.0f); } - MINIMAL_CHECK( - backend_data->sp_processor->Encode(prompt, &backend_data->prompt_tokens) - .ok()); - - if (!backend_data->start_token.empty()) { - backend_data->prompt_tokens.insert( - backend_data->prompt_tokens.begin(), - backend_data->sp_processor->PieceToId((backend_data->start_token))); - } + backend_data->prompt_tokens = *(reinterpret_cast*>(data)); uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; // assuming max tokens is <16k @@ -275,7 +252,7 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, // Return the number of outputs for the model. int32_t LLMPipeline::backend_get_output_count( mlperf_backend_ptr_t backend_ptr) { - return 2; // 0 is the output string, 1 is the output tokens + return 1; // 0 is the output tokens } // Return the type of ith output. @@ -290,18 +267,10 @@ mlperf_status_t LLMPipeline::backend_get_output( void **data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - if (i == 0) { - MINIMAL_CHECK( - backend_data->sp_processor - ->Decode(backend_data->output_tokens, &backend_data->output) - .ok()); - LOG(INFO) << "Output: " << backend_data->output << std::endl; - - *data = backend_data->output.data(); - } else if (i == 1) { - *data = &backend_data->output_tokens; - } else + if (i != 0) return MLPERF_FAILURE; + + *data = reinterpret_cast(&backend_data->output_tokens); return MLPERF_SUCCESS; } @@ -421,16 +390,6 @@ tflite::SignatureRunner *LLMPipeline::GetDecodeRunner( return runner; } -sentencepiece::SentencePieceProcessor *LLMPipeline::LoadSentencePieceProcessor( - std::string path) { - std::ifstream input(path, std::ios::binary); - std::string serialized_proto = std::string( - std::istreambuf_iterator(input), std::istreambuf_iterator()); - auto processor = new sentencepiece::SentencePieceProcessor(); - MINIMAL_CHECK_PTR(processor->LoadFromSerializedProto(serialized_proto).ok()); - return processor; -} - // A basic greedy sampler (equivalent to argmax). int LLMPipeline::GreedySampler(const TfLiteTensor *logits) { float max_value = -std::numeric_limits::infinity(); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 24f212a7f..dde5dbfe4 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -22,7 +22,6 @@ limitations under the License. #include "flutter/cpp/c/type.h" #include "pipeline.h" -#include "src/sentencepiece_processor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/experimental/genai/genai_ops.h" #include "tensorflow/lite/interpreter.h" @@ -125,7 +124,6 @@ struct LLMBackendData { const char *vendor = "Google"; const char *accelerator = "CPU"; tflite::FlatBufferModel *model{nullptr}; - sentencepiece::SentencePieceProcessor *sp_processor{nullptr}; // TfLiteInterpreterOptions *options{}; TODO use this to allow different // delegates other than CPU? tflite::Interpreter *interpreter{}; @@ -135,18 +133,14 @@ struct LLMBackendData { kv_cache_t kv_cache; std::vector prompt_tokens; std::vector output_tokens; - std::string output; uint8_t threads = 30; int max_output_tokens = 2; - std::string start_token = ""; - std::string end_token = ""; int stop_token_id = -1; LLMBackendData() {} ~LLMBackendData() { // Runners are owned by interpreter and therefore don't need to be deleted - delete sp_processor; delete interpreter; delete model; } @@ -222,8 +216,6 @@ class LLMPipeline : public Pipeline { kv_cache_t &kv_cache); tflite::SignatureRunner *GetDecodeRunner(tflite::Interpreter *interpreter, kv_cache_t &kv_cache); - sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( - std::string path); int GreedySampler(const TfLiteTensor *logits); }; From 765817e47419ba9083e664a159b34075ea95c002 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sun, 5 Oct 2025 14:08:41 +0300 Subject: [PATCH 22/74] added files needed for MMLU utils --- flutter/cpp/datasets/mmlu_utils/BUILD | 38 +++++++++++++++++++ .../datasets/mmlu_utils/sentencepiece_utils.h | 38 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 flutter/cpp/datasets/mmlu_utils/BUILD create mode 100644 flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h diff --git a/flutter/cpp/datasets/mmlu_utils/BUILD b/flutter/cpp/datasets/mmlu_utils/BUILD new file mode 100644 index 000000000..7470937a7 --- /dev/null +++ b/flutter/cpp/datasets/mmlu_utils/BUILD @@ -0,0 +1,38 @@ +# Copyright 2025 The MLPerf Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "mmlu_utils", + hdrs = [ + "sentencepiece_utils.h" + ], + copts = select({ + "//flutter/android/commonlibs:use_asan": [ + "-fsanitize=address", + "-g", + "-O1", + "-fno-omit-frame-pointer", + ], + "//conditions:default": [], + }), + deps = [ + "@com_google_sentencepiece//:sentencepiece_processor" + ] +) diff --git a/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h new file mode 100644 index 000000000..2cde7af08 --- /dev/null +++ b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2025 The MLPerf Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLPERF_DATASETS_MMLU_UTILS_SENTENCEPIECE_UTILS_H_ +#define MLPERF_DATASETS_MMLU_UTILS_SENTENCEPIECE_UTILS_H_ + +#include + +#include "src/sentencepiece_processor.h" + +namespace mlperf { +namespace mobile { + +static sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( + std::string path) { + std::ifstream input(path, std::ios::binary); + std::string serialized_proto = std::string( + std::istreambuf_iterator(input), std::istreambuf_iterator()); + auto processor = new sentencepiece::SentencePieceProcessor(); + processor->LoadFromSerializedProto(serialized_proto).ok(); + return processor; +} + +} // namespace mobile +} // namespace mlperf +#endif // MLPERF_DATASETS_MMLU_UTILS_SENTENCEPIECE_UTILS_H_ From 2e887cdbc0c8347a3e1e12a2c51220b64b6b6fcf Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sun, 5 Oct 2025 14:13:33 +0300 Subject: [PATCH 23/74] clang-format --- flutter/cpp/datasets/mmlu_gen.cc | 37 +++++++++---------- flutter/cpp/datasets/mmlu_gen.h | 6 +-- .../datasets/mmlu_utils/sentencepiece_utils.h | 2 +- .../cpp/backend_tflite/llm_pipeline.cc | 17 +++++---- .../cpp/backend_tflite/llm_pipeline.h | 5 ++- 5 files changed, 34 insertions(+), 33 deletions(-) diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 8ab3c98da..67666ce97 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -11,10 +11,11 @@ namespace mlperf { namespace mobile { // TODO add eos and bos tokens as config parameters -MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, const std::string& sp_path, bool zero_shot) +MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, + const std::string& sp_path, bool zero_shot) : sample_reader_(input_tfrecord), Dataset(backend) { - - sp_processor = std::unique_ptr(LoadSentencePieceProcessor(sp_path)); + sp_processor = std::unique_ptr( + LoadSentencePieceProcessor(sp_path)); start_token_id = sp_processor->PieceToId(start_token); end_token_id = sp_processor->PieceToId(end_token); @@ -30,7 +31,10 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, const std: std::string answer = tensorflow::GetFeatureValues("answer", example).Get(0); - if (zero_shot) input = input.substr(input.rfind("\n\n")+2); // input-formatted shots are separated by 2 new lines + if (zero_shot) + input = input.substr( + input.rfind("\n\n") + + 2); // input-formatted shots are separated by 2 new lines std::vector input_tokens; @@ -65,7 +69,8 @@ std::vector MmluGen::GetData(int sample_idx) { std::vector data; if (sample_idx < samples_.size()) { - data.push_back(reinterpret_cast(const_cast*>(&(samples_[sample_idx]->input_tokens)))); + data.push_back(reinterpret_cast( + const_cast*>(&(samples_[sample_idx]->input_tokens)))); data.push_back(reinterpret_cast(const_cast(&end_token_id))); } return data; @@ -75,8 +80,8 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { if (sample_idx >= samples_.size() || outputs.empty()) return {0}; - const auto& output_tokens = *(reinterpret_cast*>(outputs[0])); - LOG(INFO) << "~getout " << output_tokens.size() << std::endl; + const auto& output_tokens = + *(reinterpret_cast*>(outputs[0])); sample_output_token_counts_[sample_idx] = output_tokens.size(); @@ -110,27 +115,21 @@ std::string MmluGen::ComputeAccuracyString() { } char MmluGen::find_answer_char(const std::string& input) { - - const unsigned char* c = reinterpret_cast(input.c_str()); + const unsigned char* c = + reinterpret_cast(input.c_str()); while (*c) { // skip leading whitespace while (*c && std::isspace(*c)) ++c; if (!*c) break; - const unsigned char* start = c; // start of word + const unsigned char* start = c; // start of word // quick check: is the word exactly 1 char long? - ++c; // move to potential second char + ++c; // move to potential second char if (!*c || std::isspace(*c) || *c == '<') { - if (*start == 'A' || - *start == 'B' || - *start == 'C' || - *start == 'D' || - *start == 'a' || - *start == 'b' || - *start == 'c' || - *start == 'd') + if (*start == 'A' || *start == 'B' || *start == 'C' || *start == 'D' || + *start == 'a' || *start == 'b' || *start == 'c' || *start == 'd') return *start; } else { // skip rest of this (longer) word quickly diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 9e05508ae..8c21caaaa 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -10,16 +10,17 @@ #include #include -#include "src/sentencepiece_processor.h" #include "flutter/cpp/dataset.h" #include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" +#include "src/sentencepiece_processor.h" namespace mlperf { namespace mobile { class MmluGen : public Dataset { public: - MmluGen(Backend* backend, const std::string& input_tfrecord, const std::string& sp_path, bool zero_shot); + MmluGen(Backend* backend, const std::string& input_tfrecord, + const std::string& sp_path, bool zero_shot); const std::string& Name() override { return name_; } @@ -66,7 +67,6 @@ class MmluGen : public Dataset { size_t correct_ = 0; size_t total_ = 0; - std::string start_token = ""; std::string end_token = ""; int start_token_id; diff --git a/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h index 2cde7af08..f4a2a76fe 100644 --- a/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h +++ b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h @@ -24,7 +24,7 @@ namespace mlperf { namespace mobile { static sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( - std::string path) { + std::string path) { std::ifstream input(path, std::ios::binary); std::string serialized_proto = std::string( std::istreambuf_iterator(input), std::istreambuf_iterator()); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 24edb9e11..a6ee4434f 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -150,7 +150,7 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query( // Run the output token producing decode inference. // This function exclusively takes output tokens to produce more output tokens. mlperf_status_t LLMPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void *context) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; backend_issue_first_token_query(backend_ptr); @@ -192,8 +192,8 @@ mlperf_status_t LLMPipeline::backend_flush_queries( } // Return the number of inputs of the model. -// Only 2 inputs need to be provided, the tokens themselves, and the EOS token. The other -// inputs are handled by the pipeline +// Only 2 inputs need to be provided, the tokens themselves, and the EOS token. +// The other inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { return 2; } @@ -211,7 +211,9 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; if (i == 1) { - backend_data->stop_token_id = *(reinterpret_cast(data)); + backend_data->stop_token_id = *(reinterpret_cast(data)); + LOG(INFO) << "stop token id: " + << std::to_string(backend_data->stop_token_id) << std::endl; return MLPERF_SUCCESS; } @@ -222,7 +224,7 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, std::fill(vec.begin(), vec.end(), 0.0f); } - backend_data->prompt_tokens = *(reinterpret_cast*>(data)); + backend_data->prompt_tokens = *(reinterpret_cast *>(data)); uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; // assuming max tokens is <16k @@ -267,10 +269,9 @@ mlperf_status_t LLMPipeline::backend_get_output( void **data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - if (i != 0) - return MLPERF_FAILURE; + if (i != 0) return MLPERF_FAILURE; - *data = reinterpret_cast(&backend_data->output_tokens); + *data = reinterpret_cast(&backend_data->output_tokens); return MLPERF_SUCCESS; } diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index dde5dbfe4..5ca662d60 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -172,8 +172,9 @@ class LLMPipeline : public Pipeline { mlperf_status_t backend_issue_first_token_query( mlperf_backend_ptr_t backend_ptr); - mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; + mlperf_status_t backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; From a3b0799064f7d1209fb02d29126b97b06322ca11 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sun, 5 Oct 2025 15:10:31 +0300 Subject: [PATCH 24/74] continued formatting --- flutter/cpp/backends/external.h | 7 ++++--- flutter/cpp/binary/main.cc | 10 +++++----- flutter/cpp/flutter/dart_run_benchmark.cc | 3 +-- flutter/cpp/mlperf_driver.cc | 18 ++++++++++++------ .../cpp/backend_tflite/pipeline.h | 5 +++-- .../backend_tflite/single_model_pipeline.cc | 2 +- .../cpp/backend_tflite/single_model_pipeline.h | 5 +++-- .../backend_tflite/stable_diffusion_pipeline.h | 5 +++-- .../cpp/backend_tflite/tflite_c.cc | 4 +++- 9 files changed, 35 insertions(+), 24 deletions(-) diff --git a/flutter/cpp/backends/external.h b/flutter/cpp/backends/external.h index ca9a4296e..04faddcd8 100644 --- a/flutter/cpp/backends/external.h +++ b/flutter/cpp/backends/external.h @@ -47,8 +47,8 @@ struct BackendFunctions { using AcceleratorNamePtr = std::add_pointer::type; using BackendDeletePtr = std::add_pointer::type; - using IssueQueryPtr = - std::add_pointer::type; + using IssueQueryPtr = std::add_pointer::type; using FlushQueriesPtr = std::add_pointer::type; @@ -158,7 +158,8 @@ class ExternalBackend : public Backend { // Run inference for a sample. void IssueQuery(ft_callback callback, void* context) override { - if (backend_functions_.issue_query(backend_ptr_, callback, context) != MLPERF_SUCCESS) { + if (backend_functions_.issue_query(backend_ptr_, callback, context) != + MLPERF_SUCCESS) { LOG(FATAL) << "Error while inferencing model"; } } diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index ddc1ecda9..24e5cc4d2 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -403,10 +403,9 @@ int Main(int argc, char *argv[]) { "input_tfrecord", &input_tfrecord, "Path to the tfrecord file containing inputs for the model.", Flag::kRequired), - Flag::CreateFlag( - "sp_path", &sp_path, - "Path to the sentencepiece model file.", - Flag::kRequired), + Flag::CreateFlag("sp_path", &sp_path, + "Path to the sentencepiece model file.", + Flag::kRequired), Flag::CreateFlag( "zero-shot", &zero_shot, "Use zero-shot prompts instead of the default few-shot."), @@ -414,7 +413,8 @@ int Main(int argc, char *argv[]) { if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && backend) { - dataset.reset(new MmluGen(backend.get(), input_tfrecord, sp_path, zero_shot)); + dataset.reset( + new MmluGen(backend.get(), input_tfrecord, sp_path, zero_shot)); } // Adds to flag_list for showing help. flag_list.insert(flag_list.end(), dataset_flags.begin(), diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index 49d804d73..ef7b0bac4 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -108,8 +108,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( in->output_dir); break; case ::mlperf::mobile::DatasetConfig::MMLU: - for (auto setting : settings.benchmark_setting().custom_setting()) - { + for (auto setting : settings.benchmark_setting().custom_setting()) { if (setting.id() == "llm_tokenizer_path") sp_path = setting.value(); } dataset = std::make_unique<::mlperf::mobile::MmluGen>( diff --git a/flutter/cpp/mlperf_driver.cc b/flutter/cpp/mlperf_driver.cc index 343f627d1..6662d1ca5 100644 --- a/flutter/cpp/mlperf_driver.cc +++ b/flutter/cpp/mlperf_driver.cc @@ -28,9 +28,11 @@ limitations under the License. namespace mlperf { namespace mobile { -// A method to be called by the backend as soon as the first token is generated (only for token based benchmarks) +// A method to be called by the backend as soon as the first token is generated +// (only for token based benchmarks) static void FirstTokenCallback(void* context) { - auto ft_responses = *(reinterpret_cast*>(context)); + auto ft_responses = + *(reinterpret_cast*>(context)); ::mlperf::FirstTokenComplete(ft_responses.data(), ft_responses.size()); } @@ -56,9 +58,11 @@ void MlperfDriver::IssueQuery( // TODO maybe don't do these 2 lines for non token stuff // TODO figure out what this vector sample variable is ft_responses.clear(); - ft_responses.push_back({sample.back().index, reinterpret_cast(nullptr), 0}); + ft_responses.push_back( + {sample.back().index, reinterpret_cast(nullptr), 0}); - backend_->IssueQuery(&FirstTokenCallback, reinterpret_cast(&ft_responses)); + backend_->IssueQuery(&FirstTokenCallback, + reinterpret_cast(&ft_responses)); for (int b = 0; b < batch_; b++) { if (idx + b == samples.size()) break; // ignore extra data @@ -82,9 +86,11 @@ void MlperfDriver::IssueQuery( // TODO maybe don't do these 2 lines for non token stuff ft_responses.clear(); - ft_responses.push_back({sample.id, reinterpret_cast(nullptr), 0}); + ft_responses.push_back( + {sample.id, reinterpret_cast(nullptr), 0}); - backend_->IssueQuery(&FirstTokenCallback, reinterpret_cast(&ft_responses)); + backend_->IssueQuery(&FirstTokenCallback, + reinterpret_cast(&ft_responses)); // Report to mlperf. std::vector outputs = backend_->GetPredictedOutputs(); diff --git a/mobile_back_tflite/cpp/backend_tflite/pipeline.h b/mobile_back_tflite/cpp/backend_tflite/pipeline.h index ad82c002a..3694c8a0a 100644 --- a/mobile_back_tflite/cpp/backend_tflite/pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/pipeline.h @@ -41,8 +41,9 @@ class Pipeline { virtual const char *backend_name(mlperf_backend_ptr_t backend_ptr) = 0; // Run the inference for a sample. - virtual mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) = 0; + virtual mlperf_status_t backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) = 0; // Flush the staged queries immediately. virtual mlperf_status_t backend_flush_queries( diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc index bf60dddbf..5b8420a0c 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.cc @@ -425,7 +425,7 @@ const char *SingleModelPipeline::backend_name( // Run the inference for a sample. mlperf_status_t SingleModelPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void *context) { TFLiteBackendData *backend_data = (TFLiteBackendData *)backend_ptr; #ifdef MTK_TFLITE_NEURON_BACKEND diff --git a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h index 37b323b87..04755420e 100644 --- a/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/single_model_pipeline.h @@ -37,8 +37,9 @@ class SingleModelPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; + mlperf_status_t backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h index b2dbdbcb4..6f3bad31e 100644 --- a/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/stable_diffusion_pipeline.h @@ -64,8 +64,9 @@ class StableDiffusionPipeline : public Pipeline { const char *backend_name(mlperf_backend_ptr_t backend_ptr) override; - mlperf_status_t backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) override; + mlperf_status_t backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) override; mlperf_status_t backend_flush_queries( mlperf_backend_ptr_t backend_ptr) override; diff --git a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc index 1bb6b96a5..9020bd4fd 100644 --- a/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc +++ b/mobile_back_tflite/cpp/backend_tflite/tflite_c.cc @@ -190,7 +190,9 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { // Run the inference for a sample. // callback and context are only used when running token based inferences (LLM). // In other cases they can be passed as nullptr -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) { return pipeline->backend_issue_query(backend_ptr, callback, context); } From 5a96013de34b452c5802b06da796c6f9cd39cd9d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 6 Oct 2025 06:38:28 +0300 Subject: [PATCH 25/74] code cleanup / issue_query signature update to vendor backends --- flutter/cpp/backend.h | 3 +- mobile_back_apple/cpp/backend_coreml/main.cc | 4 ++- .../pixel_single_model_pipeline.cc | 2 +- .../cpp/backend_tflite/tflite_pixel.cc | 6 ++-- .../cpp/backend_mock_qti/qti_mock_c.cc | 4 ++- mobile_back_qti/cpp/backend_qti/acpitabl.h | 28 +++++++++---------- mobile_back_qti/cpp/backend_qti/cpuctrl.cc | 12 +++++--- mobile_back_qti/cpp/backend_qti/tflite_c.cc | 4 ++- mobile_back_qti/cpp/backend_qti/tflite_c.h | 3 +- .../samsung/lib/public/mbe_core/mbe_core.cc | 4 ++- .../cpp/backend_tflite/neuron/BUILD | 4 +++ 11 files changed, 47 insertions(+), 27 deletions(-) diff --git a/flutter/cpp/backend.h b/flutter/cpp/backend.h index 6a73aee0d..a1e5fba4e 100644 --- a/flutter/cpp/backend.h +++ b/flutter/cpp/backend.h @@ -45,7 +45,8 @@ class Backend { virtual const std::string& AcceleratorName() const = 0; // Run inference for a sample. Inputs is already set by SetInputs. - // TODO might be good to provide the callback and context along with the inputs if possible + // TODO might be good to provide the callback and context along with the + // inputs if possible virtual void IssueQuery(ft_callback callback, void* context) = 0; // Flush the staged queries immediately. diff --git a/mobile_back_apple/cpp/backend_coreml/main.cc b/mobile_back_apple/cpp/backend_coreml/main.cc index af753d566..d53918870 100644 --- a/mobile_back_apple/cpp/backend_coreml/main.cc +++ b/mobile_back_apple/cpp/backend_coreml/main.cc @@ -136,7 +136,9 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) { CoreMLBackendData *backend_data = (CoreMLBackendData *)backend_ptr; if ([backend_data->coreMLExecutor issueQueries]) return MLPERF_SUCCESS; return MLPERF_FAILURE; diff --git a/mobile_back_pixel/cpp/backend_tflite/pixel_single_model_pipeline.cc b/mobile_back_pixel/cpp/backend_tflite/pixel_single_model_pipeline.cc index 1d44b411f..b14e0bef6 100644 --- a/mobile_back_pixel/cpp/backend_tflite/pixel_single_model_pipeline.cc +++ b/mobile_back_pixel/cpp/backend_tflite/pixel_single_model_pipeline.cc @@ -271,7 +271,7 @@ void SingleModelPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { // Run the inference for a sample. mlperf_status_t SingleModelPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr) { + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { TFLiteBackendData* backend_data = (TFLiteBackendData*)backend_ptr; auto task = [&backend_data](int index) -> TfLiteStatus { return TfLiteInterpreterInvoke(backend_data->interpreter[index]); diff --git a/mobile_back_pixel/cpp/backend_tflite/tflite_pixel.cc b/mobile_back_pixel/cpp/backend_tflite/tflite_pixel.cc index 476b7ae60..49d081970 100644 --- a/mobile_back_pixel/cpp/backend_tflite/tflite_pixel.cc +++ b/mobile_back_pixel/cpp/backend_tflite/tflite_pixel.cc @@ -95,8 +95,10 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { - return pipeline->backend_issue_query(backend_ptr); +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) { + return pipeline->backend_issue_query(backend_ptr, callback, context); } // Flush the staged queries immediately. diff --git a/mobile_back_qti/cpp/backend_mock_qti/qti_mock_c.cc b/mobile_back_qti/cpp/backend_mock_qti/qti_mock_c.cc index fbf3d5292..bd8a649de 100644 --- a/mobile_back_qti/cpp/backend_mock_qti/qti_mock_c.cc +++ b/mobile_back_qti/cpp/backend_mock_qti/qti_mock_c.cc @@ -60,7 +60,9 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void* context) { return MLPERF_FAILURE; } diff --git a/mobile_back_qti/cpp/backend_qti/acpitabl.h b/mobile_back_qti/cpp/backend_qti/acpitabl.h index f0e709c0f..28d96760a 100644 --- a/mobile_back_qti/cpp/backend_qti/acpitabl.h +++ b/mobile_back_qti/cpp/backend_qti/acpitabl.h @@ -264,7 +264,7 @@ typedef struct _ACPI_SRAT { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _ACPI_SRAT_ENTRY { UCHAR Type; @@ -435,7 +435,7 @@ typedef struct _ACPI_MPST { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _POWER_STATE_CHARACTERISTICS { union { @@ -934,10 +934,10 @@ typedef LOCAL_X2APIC_NMISOURCE UNALIGNED *PLOCAL_X2APIC_NMISOURCE; _COMPRESSED_ |= _AFF3_; \ } -#define UNCOMPRESS_MPIDR(_COMPRESSED_, _MPIDR_) \ - { \ - (_MPIDR_) = (ULONGLONG)(_COMPRESSED_)&0x00FFFFFFULL; \ - (_MPIDR_) |= ((ULONGLONG)(_COMPRESSED_)&0xFF000000ULL) << 8; \ +#define UNCOMPRESS_MPIDR(_COMPRESSED_, _MPIDR_) \ + { \ + (_MPIDR_) = (ULONGLONG)(_COMPRESSED_) & 0x00FFFFFFULL; \ + (_MPIDR_) |= ((ULONGLONG)(_COMPRESSED_) & 0xFF000000ULL) << 8; \ } typedef struct _PROCLOCALGIC { @@ -1852,7 +1852,7 @@ C_ASSERT(WAET_DEV_RTC_ENLIGHTENED == 1); #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union // // Top-level IORT table @@ -2155,7 +2155,7 @@ typedef struct _RHSA { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _DMARTABLE { USHORT Type; @@ -2653,7 +2653,7 @@ typedef struct _PCC_TABLE { #define BGRT_STATUS_DISPLAY_ROTATION 0x06 #define BGRT_STATUS_GET_DISPLAY_ROTATION(_Status_) \ - ((UCHAR)((ULONG)((_Status_)&BGRT_STATUS_DISPLAY_ROTATION) >> 1)) + ((UCHAR)((ULONG)((_Status_) & BGRT_STATUS_DISPLAY_ROTATION) >> 1)) typedef enum _BGRT_IMAGE_TYPE { BgrtImageTypeBitmap, @@ -3065,7 +3065,7 @@ typedef struct _ACPI_PLD_V2_BUFFER { // Color bits 8:31 (Red 8:15, Green 16:23, Blue 24:31) #define ACPI_PLD_MAKE_COLOR(r, g, b) \ - ((UINT32)(((r)&0xFF) | (((g)&0xFF) << 8) | (((b)&0xFF) << 16))) + ((UINT32)(((r) & 0xFF) | (((g) & 0xFF) << 8) | (((b) & 0xFF) << 16))) #define ACPI_PLD_COLOR_RED(c) ((BYTE)(((c) >> 0) & 0xFF)) #define ACPI_PLD_COLOR_GREEN(c) ((BYTE)(((c) >> 8) & 0xFF)) #define ACPI_PLD_COLOR_BLUE(c) ((BYTE)(((c) >> 16) & 0xFF)) @@ -3610,7 +3610,7 @@ typedef struct _NFIT_PLATFORM_CAPABILITIES { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3646,7 +3646,7 @@ typedef struct _WSMT { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3692,7 +3692,7 @@ typedef struct _LPIT { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3844,7 +3844,7 @@ typedef struct _ACPI_PDTT { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #define HMAT_SIGNATURE 0x54414D48 // "HMAT" diff --git a/mobile_back_qti/cpp/backend_qti/cpuctrl.cc b/mobile_back_qti/cpp/backend_qti/cpuctrl.cc index bcc401e13..49c333b31 100644 --- a/mobile_back_qti/cpp/backend_qti/cpuctrl.cc +++ b/mobile_back_qti/cpp/backend_qti/cpuctrl.cc @@ -35,13 +35,17 @@ using namespace std::chrono; #define GET_AFFINITY(a, b) sched_getaffinity(gettid(), a, b) #else #define SET_AFFINITY(a, b) \ - {} + { \ + } #define GET_AFFINITY(a, b) \ - {} + { \ + } #define CPU_ZERO(a) \ - {} + { \ + } #define CPU_SET(a, b) \ - {} + { \ + } #endif static uint32_t soc_id_ = 0; diff --git a/mobile_back_qti/cpp/backend_qti/tflite_c.cc b/mobile_back_qti/cpp/backend_qti/tflite_c.cc index ec929cb17..8d63493ac 100644 --- a/mobile_back_qti/cpp/backend_qti/tflite_c.cc +++ b/mobile_back_qti/cpp/backend_qti/tflite_c.cc @@ -243,7 +243,9 @@ mlperf_backend_ptr_t tflite_backend_create( const char* model_path, mlperf_backend_configuration_t* configs) { return NULL; } -mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void* context) { return MLPERF_SUCCESS; } mlperf_status_t tflite_backend_flush_queries(mlperf_backend_ptr_t backend_ptr) { diff --git a/mobile_back_qti/cpp/backend_qti/tflite_c.h b/mobile_back_qti/cpp/backend_qti/tflite_c.h index 8f287b721..2c29fe7aa 100644 --- a/mobile_back_qti/cpp/backend_qti/tflite_c.h +++ b/mobile_back_qti/cpp/backend_qti/tflite_c.h @@ -23,7 +23,8 @@ mlperf_backend_ptr_t tflite_backend_create( // Destroy the backend pointer and its data. void tflite_backend_delete(mlperf_backend_ptr_t backend_ptr); // Run the inference for a sample. -mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr); +mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, void* context); // Flush the staged queries immediately. mlperf_status_t tflite_backend_flush_queries(mlperf_backend_ptr_t backend_ptr); // Return the number of inputs of the model. diff --git a/mobile_back_samsung/samsung/lib/public/mbe_core/mbe_core.cc b/mobile_back_samsung/samsung/lib/public/mbe_core/mbe_core.cc index 59ca19ab2..6f4fce544 100644 --- a/mobile_back_samsung/samsung/lib/public/mbe_core/mbe_core.cc +++ b/mobile_back_samsung/samsung/lib/public/mbe_core/mbe_core.cc @@ -213,7 +213,9 @@ mlperf_data_t mlperf_backend_get_output_type(mlperf_backend_ptr_t backend_ptr, return data; } -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) { mbe_core_holder *ptr = (mbe_core_holder *)backend_ptr; MLOGD("+ mlperf_backend_issue_query with ptr[%p]", ptr); intf_mlperf_status_t intf_status = ptr->issue_query_fp(); diff --git a/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD b/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD index e9e8cf9a7..82c3adc5d 100644 --- a/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD @@ -33,6 +33,7 @@ cc_library( srcs = [ "neuron_backend.cc", "//mobile_back_tflite/cpp/backend_tflite:sd_utils.cc", + "//mobile_back_tflite/cpp/backend_tflite:llm_pipeline.cc", "//mobile_back_tflite/cpp/backend_tflite:single_model_pipeline.cc", "//mobile_back_tflite/cpp/backend_tflite:stable_diffusion_invoker.cc", "//mobile_back_tflite/cpp/backend_tflite:stable_diffusion_pipeline.cc", @@ -45,6 +46,7 @@ cc_library( "neuron_backend.h", "neuron_builder.h", "tflite_settings_mtk.h", + "//mobile_back_tflite/cpp/backend_tflite:llm_pipeline.h", "//mobile_back_tflite/cpp/backend_tflite:pipeline.h", "//mobile_back_tflite/cpp/backend_tflite:sd_utils.h", "//mobile_back_tflite/cpp/backend_tflite:single_model_pipeline.h", @@ -77,6 +79,8 @@ cc_library( "@org_tensorflow//tensorflow/core:tflite_portable_logging", "@org_tensorflow//tensorflow/lite/c:c_api", "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/experimental/genai:genai_ops", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ] + select({ "@org_tensorflow//tensorflow:android": [ ":neuron_delegate", From 3a54a666b53b6bccc6c3c3bb5a4616add4d8ebc7 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 6 Oct 2025 07:20:44 +0300 Subject: [PATCH 26/74] signature update for QTI/Samsung backends --- mobile_back_qti/cpp/backend_qti/qti_c.cc | 7 +++++-- mobile_back_qti/cpp/backend_qti/tflite_c.cc | 4 +++- mobile_back_samsung/samsung/lib/include/type_interfaced.h | 3 +++ .../samsung/lib/public/include/mbe_core_holder.hpp | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mobile_back_qti/cpp/backend_qti/qti_c.cc b/mobile_back_qti/cpp/backend_qti/qti_c.cc index 08de9874c..d80caaec9 100644 --- a/mobile_back_qti/cpp/backend_qti/qti_c.cc +++ b/mobile_back_qti/cpp/backend_qti/qti_c.cc @@ -199,7 +199,9 @@ void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void *context) { mlperf_status_t ret = MLPERF_FAILURE; QTIBackendHelper *backend_data = (QTIBackendHelper *)backend_ptr; @@ -208,7 +210,8 @@ mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { auto start = high_resolution_clock::now(); #endif if (backend_data->isTflite_) { - return tflite_backend_issue_query(backend_data->tfliteBackend_); + return tflite_backend_issue_query(backend_data->tfliteBackend_, callback, + context); } if (backend_data->isStableDiffusion) { diff --git a/mobile_back_qti/cpp/backend_qti/tflite_c.cc b/mobile_back_qti/cpp/backend_qti/tflite_c.cc index 8d63493ac..abde79afd 100644 --- a/mobile_back_qti/cpp/backend_qti/tflite_c.cc +++ b/mobile_back_qti/cpp/backend_qti/tflite_c.cc @@ -145,7 +145,9 @@ void tflite_backend_delete(mlperf_backend_ptr_t backend_ptr) { } // Run the inference for a sample. -mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t tflite_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void* context) { TFLiteBackendData* backend_data = (TFLiteBackendData*)backend_ptr; if (TfLiteInterpreterInvoke(backend_data->interpreter) != kTfLiteOk) { printf("Failed to run the inference"); diff --git a/mobile_back_samsung/samsung/lib/include/type_interfaced.h b/mobile_back_samsung/samsung/lib/include/type_interfaced.h index e79d53f6c..c7e2586d7 100755 --- a/mobile_back_samsung/samsung/lib/include/type_interfaced.h +++ b/mobile_back_samsung/samsung/lib/include/type_interfaced.h @@ -60,6 +60,9 @@ typedef struct { const char* values[kMaxMLPerfBackendConfigs_intf]; } intf_mlperf_backend_configuration_t; + +typedef void (*ft_callback)(void* context); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp b/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp index 3b1a64e0e..c21568fd1 100644 --- a/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp +++ b/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp @@ -44,7 +44,7 @@ class mbe_core_holder { using backend_get_output_t = std::add_pointer::type; using backend_issue_query_t = - std::add_pointer::type; + std::add_pointer::type; using backend_convert_inputs_t = std::add_pointer::type; using backend_delete_t = std::add_pointer::type; From 26e562bb9e42e2d0ae75a89313053f1880feba9e Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 6 Oct 2025 07:24:38 +0300 Subject: [PATCH 27/74] format --- mobile_back_samsung/samsung/lib/include/type_interfaced.h | 1 - 1 file changed, 1 deletion(-) mode change 100755 => 100644 mobile_back_samsung/samsung/lib/include/type_interfaced.h diff --git a/mobile_back_samsung/samsung/lib/include/type_interfaced.h b/mobile_back_samsung/samsung/lib/include/type_interfaced.h old mode 100755 new mode 100644 index c7e2586d7..da469e953 --- a/mobile_back_samsung/samsung/lib/include/type_interfaced.h +++ b/mobile_back_samsung/samsung/lib/include/type_interfaced.h @@ -60,7 +60,6 @@ typedef struct { const char* values[kMaxMLPerfBackendConfigs_intf]; } intf_mlperf_backend_configuration_t; - typedef void (*ft_callback)(void* context); #ifdef __cplusplus From d485523e064e1667d9da1fc9a85dedf10d128fe8 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 6 Oct 2025 08:39:42 +0300 Subject: [PATCH 28/74] formatted clang and bazel using docker based formatter --- flutter/cpp/c/backend_c.h | 3 +- flutter/cpp/datasets/BUILD | 2 +- flutter/cpp/datasets/mmlu_utils/BUILD | 6 ++-- mobile_back_qti/cpp/backend_qti/acpitabl.h | 28 +++++++++---------- mobile_back_qti/cpp/backend_qti/cpuctrl.cc | 12 +++----- .../cpp/backend_dummy/dummy_backend.cc | 4 ++- .../cpp/backend_tflite/neuron/BUILD | 2 +- 7 files changed, 28 insertions(+), 29 deletions(-) diff --git a/flutter/cpp/c/backend_c.h b/flutter/cpp/c/backend_c.h index 47d3c9bb5..0fd22169e 100644 --- a/flutter/cpp/c/backend_c.h +++ b/flutter/cpp/c/backend_c.h @@ -55,7 +55,8 @@ const char* mlperf_backend_name(mlperf_backend_ptr_t backend_ptr); void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr); // Run the inference for a sample. -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr); +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, void* context); // Flush the staged queries immediately. mlperf_status_t mlperf_backend_flush_queries(mlperf_backend_ptr_t backend_ptr); diff --git a/flutter/cpp/datasets/BUILD b/flutter/cpp/datasets/BUILD index dd739b0de..24314ea10 100644 --- a/flutter/cpp/datasets/BUILD +++ b/flutter/cpp/datasets/BUILD @@ -228,8 +228,8 @@ cc_library( "//flutter/cpp:mlperf_driver", "//flutter/cpp:utils", "//flutter/cpp/backends:external", - "//flutter/cpp/datasets/squad_utils", "//flutter/cpp/datasets/mmlu_utils", + "//flutter/cpp/datasets/squad_utils", "@com_google_sentencepiece//:sentencepiece_processor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_protobuf//:protobuf", diff --git a/flutter/cpp/datasets/mmlu_utils/BUILD b/flutter/cpp/datasets/mmlu_utils/BUILD index 7470937a7..dc1051e66 100644 --- a/flutter/cpp/datasets/mmlu_utils/BUILD +++ b/flutter/cpp/datasets/mmlu_utils/BUILD @@ -21,7 +21,7 @@ package( cc_library( name = "mmlu_utils", hdrs = [ - "sentencepiece_utils.h" + "sentencepiece_utils.h", ], copts = select({ "//flutter/android/commonlibs:use_asan": [ @@ -33,6 +33,6 @@ cc_library( "//conditions:default": [], }), deps = [ - "@com_google_sentencepiece//:sentencepiece_processor" - ] + "@com_google_sentencepiece//:sentencepiece_processor", + ], ) diff --git a/mobile_back_qti/cpp/backend_qti/acpitabl.h b/mobile_back_qti/cpp/backend_qti/acpitabl.h index 28d96760a..f0e709c0f 100644 --- a/mobile_back_qti/cpp/backend_qti/acpitabl.h +++ b/mobile_back_qti/cpp/backend_qti/acpitabl.h @@ -264,7 +264,7 @@ typedef struct _ACPI_SRAT { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _ACPI_SRAT_ENTRY { UCHAR Type; @@ -435,7 +435,7 @@ typedef struct _ACPI_MPST { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _POWER_STATE_CHARACTERISTICS { union { @@ -934,10 +934,10 @@ typedef LOCAL_X2APIC_NMISOURCE UNALIGNED *PLOCAL_X2APIC_NMISOURCE; _COMPRESSED_ |= _AFF3_; \ } -#define UNCOMPRESS_MPIDR(_COMPRESSED_, _MPIDR_) \ - { \ - (_MPIDR_) = (ULONGLONG)(_COMPRESSED_) & 0x00FFFFFFULL; \ - (_MPIDR_) |= ((ULONGLONG)(_COMPRESSED_) & 0xFF000000ULL) << 8; \ +#define UNCOMPRESS_MPIDR(_COMPRESSED_, _MPIDR_) \ + { \ + (_MPIDR_) = (ULONGLONG)(_COMPRESSED_)&0x00FFFFFFULL; \ + (_MPIDR_) |= ((ULONGLONG)(_COMPRESSED_)&0xFF000000ULL) << 8; \ } typedef struct _PROCLOCALGIC { @@ -1852,7 +1852,7 @@ C_ASSERT(WAET_DEV_RTC_ENLIGHTENED == 1); #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union // // Top-level IORT table @@ -2155,7 +2155,7 @@ typedef struct _RHSA { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union typedef struct _DMARTABLE { USHORT Type; @@ -2653,7 +2653,7 @@ typedef struct _PCC_TABLE { #define BGRT_STATUS_DISPLAY_ROTATION 0x06 #define BGRT_STATUS_GET_DISPLAY_ROTATION(_Status_) \ - ((UCHAR)((ULONG)((_Status_) & BGRT_STATUS_DISPLAY_ROTATION) >> 1)) + ((UCHAR)((ULONG)((_Status_)&BGRT_STATUS_DISPLAY_ROTATION) >> 1)) typedef enum _BGRT_IMAGE_TYPE { BgrtImageTypeBitmap, @@ -3065,7 +3065,7 @@ typedef struct _ACPI_PLD_V2_BUFFER { // Color bits 8:31 (Red 8:15, Green 16:23, Blue 24:31) #define ACPI_PLD_MAKE_COLOR(r, g, b) \ - ((UINT32)(((r) & 0xFF) | (((g) & 0xFF) << 8) | (((b) & 0xFF) << 16))) + ((UINT32)(((r)&0xFF) | (((g)&0xFF) << 8) | (((b)&0xFF) << 16))) #define ACPI_PLD_COLOR_RED(c) ((BYTE)(((c) >> 0) & 0xFF)) #define ACPI_PLD_COLOR_GREEN(c) ((BYTE)(((c) >> 8) & 0xFF)) #define ACPI_PLD_COLOR_BLUE(c) ((BYTE)(((c) >> 16) & 0xFF)) @@ -3610,7 +3610,7 @@ typedef struct _NFIT_PLATFORM_CAPABILITIES { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3646,7 +3646,7 @@ typedef struct _WSMT { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3692,7 +3692,7 @@ typedef struct _LPIT { #endif #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int @@ -3844,7 +3844,7 @@ typedef struct _ACPI_PDTT { #pragma warning(disable : 4214) // nonstandard extension used : bit field types // other than int #pragma warning( \ - disable : 4201) // nonstandard extension used : nameless struct/union + disable : 4201) // nonstandard extension used : nameless struct/union #define HMAT_SIGNATURE 0x54414D48 // "HMAT" diff --git a/mobile_back_qti/cpp/backend_qti/cpuctrl.cc b/mobile_back_qti/cpp/backend_qti/cpuctrl.cc index 49c333b31..bcc401e13 100644 --- a/mobile_back_qti/cpp/backend_qti/cpuctrl.cc +++ b/mobile_back_qti/cpp/backend_qti/cpuctrl.cc @@ -35,17 +35,13 @@ using namespace std::chrono; #define GET_AFFINITY(a, b) sched_getaffinity(gettid(), a, b) #else #define SET_AFFINITY(a, b) \ - { \ - } + {} #define GET_AFFINITY(a, b) \ - { \ - } + {} #define CPU_ZERO(a) \ - { \ - } + {} #define CPU_SET(a, b) \ - { \ - } + {} #endif static uint32_t soc_id_ = 0; diff --git a/mobile_back_tflite/cpp/backend_dummy/dummy_backend.cc b/mobile_back_tflite/cpp/backend_dummy/dummy_backend.cc index a6d455db2..4d60cfea7 100644 --- a/mobile_back_tflite/cpp/backend_dummy/dummy_backend.cc +++ b/mobile_back_tflite/cpp/backend_dummy/dummy_backend.cc @@ -25,7 +25,9 @@ const char* mlperf_backend_name(mlperf_backend_ptr_t backend_ptr) { return ""; } void mlperf_backend_delete(mlperf_backend_ptr_t backend_ptr) {} -mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr) { +mlperf_status_t mlperf_backend_issue_query(mlperf_backend_ptr_t backend_ptr, + ft_callback callback, + void* context) { return MLPERF_FAILURE; } diff --git a/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD b/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD index 82c3adc5d..832426ece 100644 --- a/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/neuron/BUILD @@ -32,8 +32,8 @@ cc_library( name = "tflite_neuron_c", srcs = [ "neuron_backend.cc", - "//mobile_back_tflite/cpp/backend_tflite:sd_utils.cc", "//mobile_back_tflite/cpp/backend_tflite:llm_pipeline.cc", + "//mobile_back_tflite/cpp/backend_tflite:sd_utils.cc", "//mobile_back_tflite/cpp/backend_tflite:single_model_pipeline.cc", "//mobile_back_tflite/cpp/backend_tflite:stable_diffusion_invoker.cc", "//mobile_back_tflite/cpp/backend_tflite:stable_diffusion_pipeline.cc", From 24cf047fe4914cd98b58da1839994e282ffa9537 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 6 Oct 2025 09:28:15 +0300 Subject: [PATCH 29/74] reverted issue_query change for samsung + bazel formatting --- flutter/cpp/datasets/BUILD | 2 +- .../samsung/lib/public/include/mbe_core_holder.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flutter/cpp/datasets/BUILD b/flutter/cpp/datasets/BUILD index 24314ea10..6b1c372df 100644 --- a/flutter/cpp/datasets/BUILD +++ b/flutter/cpp/datasets/BUILD @@ -230,9 +230,9 @@ cc_library( "//flutter/cpp/backends:external", "//flutter/cpp/datasets/mmlu_utils", "//flutter/cpp/datasets/squad_utils", - "@com_google_sentencepiece//:sentencepiece_processor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_protobuf//:protobuf", + "@com_google_sentencepiece//:sentencepiece_processor", "@org_tensorflow//tensorflow/lite/tools/evaluation:utils", "@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", ], diff --git a/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp b/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp index c21568fd1..3b1a64e0e 100644 --- a/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp +++ b/mobile_back_samsung/samsung/lib/public/include/mbe_core_holder.hpp @@ -44,7 +44,7 @@ class mbe_core_holder { using backend_get_output_t = std::add_pointer::type; using backend_issue_query_t = - std::add_pointer::type; + std::add_pointer::type; using backend_convert_inputs_t = std::add_pointer::type; using backend_delete_t = std::add_pointer::type; From 30b64641d153390cac9ab95a053c66f7f5abb8fc Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 7 Oct 2025 05:39:54 +0300 Subject: [PATCH 30/74] fix for MSVC C7555 error --- mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index a6ee4434f..5b2d12369 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -334,9 +334,9 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter *interpreter) { void LLMPipeline::PrepareRunner(tflite::SignatureRunner *runner, kv_cache_t &kv_cache) { for (auto &[name, cache] : kv_cache) { - TfLiteCustomAllocation allocation = { - .data = static_cast(cache.data()), - .bytes = cache.size() * sizeof(float)}; + TfLiteCustomAllocation allocation = {}; + allocation.data = static_cast(cache.data()); + allocation.bytes = cache.size() * sizeof(float); // Both input and output tensors are set to the same buffer. Not all // delegates support this in-place update. For those cases, we need to do // a ping-pong buffer and update the pointers between inference calls. From c294784f2facf9ffaa43908c24042faf4e286a21 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 7 Oct 2025 05:46:40 +0300 Subject: [PATCH 31/74] rough IFEval implementation using llm_instruction benchmark --- flutter/cpp/binary/BUILD | 1 + flutter/cpp/binary/main.cc | 31 + flutter/cpp/datasets/BUILD | 34 + flutter/cpp/datasets/ifeval.cc | 488 +++++++++++++ flutter/cpp/datasets/ifeval.h | 90 +++ flutter/cpp/datasets/ifeval_utils/BUILD | 38 + flutter/cpp/datasets/ifeval_utils/common.h | 131 ++++ .../ifeval_utils/generate_tfrecords.py | 140 ++++ flutter/cpp/datasets/ifeval_utils/types.h | 683 ++++++++++++++++++ flutter/cpp/proto/mlperf_task.proto | 3 +- .../tflite_settings_android.pbtxt | 23 +- .../cpp/backend_tflite/llm_pipeline.h | 2 +- 12 files changed, 1660 insertions(+), 4 deletions(-) create mode 100644 flutter/cpp/datasets/ifeval.cc create mode 100644 flutter/cpp/datasets/ifeval.h create mode 100644 flutter/cpp/datasets/ifeval_utils/BUILD create mode 100644 flutter/cpp/datasets/ifeval_utils/common.h create mode 100644 flutter/cpp/datasets/ifeval_utils/generate_tfrecords.py create mode 100644 flutter/cpp/datasets/ifeval_utils/types.h diff --git a/flutter/cpp/binary/BUILD b/flutter/cpp/binary/BUILD index 595421343..da49ed31d 100644 --- a/flutter/cpp/binary/BUILD +++ b/flutter/cpp/binary/BUILD @@ -54,6 +54,7 @@ cc_binary( "//flutter/cpp/datasets:ade20k", "//flutter/cpp/datasets:coco", "//flutter/cpp/datasets:coco_gen", + "//flutter/cpp/datasets:ifeval", "//flutter/cpp/datasets:imagenet", "//flutter/cpp/datasets:mmlu_gen", "//flutter/cpp/datasets:snu_sr", diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index 24e5cc4d2..d8510197c 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -25,6 +25,7 @@ limitations under the License. #include "flutter/cpp/datasets/ade20k.h" #include "flutter/cpp/datasets/coco.h" #include "flutter/cpp/datasets/coco_gen.h" +#include "flutter/cpp/datasets/ifeval.h" #include "flutter/cpp/datasets/imagenet.h" #include "flutter/cpp/datasets/mmlu_gen.h" #include "flutter/cpp/datasets/snu_sr.h" @@ -70,6 +71,8 @@ DatasetConfig::DatasetType Str2DatasetType(absl::string_view name) { return DatasetConfig::COCOGEN; } else if (absl::EqualsIgnoreCase(name, "MMLU")) { return DatasetConfig::MMLU; + } else if (absl::EqualsIgnoreCase(name, "IFEVAL")) { + return DatasetConfig::IFEVAL; } else if (absl::EqualsIgnoreCase(name, "DUMMY")) { return DatasetConfig::NONE; } else { @@ -91,6 +94,8 @@ DatasetConfig::DatasetType BenchmarkId2DatasetType(absl::string_view name) { return DatasetConfig::SNUSR; } else if (absl::StartsWith(name, "stable_diffusion")) { return DatasetConfig::COCOGEN; + } else if (absl::StartsWith(name, "llm_instruction")) { + return DatasetConfig::IFEVAL; } else if (absl::StartsWith(name, "llm")) { return DatasetConfig::MMLU; } else { @@ -420,6 +425,32 @@ int Main(int argc, char *argv[]) { flag_list.insert(flag_list.end(), dataset_flags.begin(), dataset_flags.end()); } break; + case DatasetConfig::IFEVAL: { + bool loose_follow = false; + LOG(INFO) << "IFEval dataset for LLM benchmark"; + std::string input_tfrecord, sp_path = ""; + std::vector dataset_flags{ + Flag::CreateFlag( + "input_tfrecord", &input_tfrecord, + "Path to the tfrecord file containing inputs for the model.", + Flag::kRequired), + Flag::CreateFlag("sp_path", &sp_path, + "Path to the sentencepiece model file.", + Flag::kRequired), + Flag::CreateFlag("loose-follow", &loose_follow, + "Whether to loosely check if the instructions are " + "being followed"), + }; + + if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && + backend) { + dataset.reset( + new IFEval(backend.get(), input_tfrecord, sp_path, loose_follow)); + } + // Adds to flag_list for showing help. + flag_list.insert(flag_list.end(), dataset_flags.begin(), + dataset_flags.end()); + } break; case DatasetConfig::NONE: default: break; diff --git a/flutter/cpp/datasets/BUILD b/flutter/cpp/datasets/BUILD index 6b1c372df..1a2a37477 100644 --- a/flutter/cpp/datasets/BUILD +++ b/flutter/cpp/datasets/BUILD @@ -237,3 +237,37 @@ cc_library( "@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", ], ) + +cc_library( + name = "ifeval", + srcs = [ + "ifeval.cc", + ], + hdrs = [ + "ifeval.h", + "utils.h", + ], + copts = tflite_copts() + select({ + "//flutter/android/commonlibs:use_asan": [ + "-fsanitize=address", + "-g", + "-O1", + "-fno-omit-frame-pointer", + ], + "//conditions:default": [], + }), + deps = [ + ":allocator", + "//flutter/cpp:mlperf_driver", + "//flutter/cpp:utils", + "//flutter/cpp/backends:external", + "//flutter/cpp/datasets/ifeval_utils", + "//flutter/cpp/datasets/mmlu_utils", + "//flutter/cpp/datasets/squad_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_protobuf//:protobuf", + "@com_google_sentencepiece//:sentencepiece_processor", + "@org_tensorflow//tensorflow/lite/tools/evaluation:utils", + "@org_tensorflow//tensorflow/lite/tools/evaluation/proto:evaluation_stages_cc_proto", + ], +) diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc new file mode 100644 index 000000000..ef31259b7 --- /dev/null +++ b/flutter/cpp/datasets/ifeval.cc @@ -0,0 +1,488 @@ +#include "flutter/cpp/datasets/ifeval.h" + +#include +#include + +#include "flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature_util.h" + +namespace mlperf { +namespace mobile { + +IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, + const std::string& sp_path, bool loose_follow) + : sample_reader_(input_tfrecord), + loose_follow_(loose_follow), + Dataset(backend) { + sp_processor = std::unique_ptr( + LoadSentencePieceProcessor(sp_path)); + start_token_id = sp_processor->PieceToId(start_token); + end_token_id = sp_processor->PieceToId(end_token); + + // Load all TFRecord samples into memory + // NOTE this can be moved to LoadSamplesToRam, but will cause delays between + // queries due to IO reads + for (size_t i = 0; i < sample_reader_.Size(); i++) { + tensorflow::tstring record = sample_reader_.ReadRecord(i); + tensorflow::Example example; + example.ParseFromString(record); + int key = tensorflow::GetFeatureValues("key", example).Get(0); + std::string prompt = + tensorflow::GetFeatureValues("prompt", example).Get(0); + auto instructions = BuildInstructions(example); + + std::vector input_tokens; + sp_processor->Encode(prompt.c_str(), &input_tokens).ok(); + + input_tokens.insert(input_tokens.begin(), start_token_id); + + auto sample = std::make_unique(); + sample->key = key; + sample->prompt = prompt; + sample->input_tokens = input_tokens; + sample->instructions = std::move(instructions); + + samples_.push_back(std::move(sample)); + sample_output_token_counts_.push_back(0); + } +} + +void IFEval::LoadSamplesToRam(const std::vector& samples) { + for (auto id : samples) { + loaded_sample_ids_.insert(id); + } +} + +void IFEval::UnloadSamplesFromRam( + const std::vector& samples) { + for (auto id : samples) { + loaded_sample_ids_.erase(id); + } +} + +std::vector IFEval::GetData(int sample_idx) { + std::vector data; + + if (sample_idx < samples_.size()) { + data.push_back(reinterpret_cast( + const_cast*>(&(samples_[sample_idx]->input_tokens)))); + data.push_back(reinterpret_cast(const_cast(&end_token_id))); + } + return data; +} + +std::vector IFEval::ProcessOutput(const int sample_idx, + const std::vector& outputs) { + if (sample_idx >= samples_.size() || outputs.empty()) return {0}; + + const auto& output_tokens = + *(reinterpret_cast*>(outputs[0])); + + LOG(INFO) << '[' + << std::accumulate(std::next(output_tokens.begin()), + output_tokens.end(), + std::to_string(output_tokens[0]), + [](std::string a, int b) { + return std::move(a) + ", " + std::to_string(b); + }) + << "]\n"; + + sample_output_token_counts_[sample_idx] = output_tokens.size(); + + std::string prediction; + sp_processor->Decode(output_tokens, &prediction).ok(); + + LOG(INFO) << "output(" << std::to_string(sample_idx) << "): " << prediction + << std::endl; + + bool is_correct = true; // Automatically pass samples with no instructions. + std::vector groups; + for (const auto& instruction : samples_[sample_idx]->instructions) { + is_correct &= instruction->IsFollowed(prediction, loose_follow_); + groups.emplace_back(instruction->Group()); + } + + for (auto group : groups) ProcessResult(group, is_correct); + + return {static_cast(is_correct)}; +} + +int64_t IFEval::GetOutputTokenCount(const int sample_idx) { + return sample_output_token_counts_[sample_idx]; +} + +bool IFEval::HasAccuracy() { return true; } + +float IFEval::ComputeAccuracy() { + uint16_t correct_sum; + uint16_t total_sum; + + correct_sum += accuracy.change_case_correct; + correct_sum += accuracy.combination_correct; + correct_sum += accuracy.detectable_content_correct; + correct_sum += accuracy.detectable_format_correct; + correct_sum += accuracy.keywords_correct; + correct_sum += accuracy.language_correct; + correct_sum += accuracy.length_constraints_correct; + correct_sum += accuracy.punctuation_correct; + correct_sum += accuracy.startend_correct; + + total_sum += accuracy.change_case_total; + total_sum += accuracy.combination_total; + total_sum += accuracy.detectable_content_total; + total_sum += accuracy.detectable_format_total; + total_sum += accuracy.keywords_total; + total_sum += accuracy.language_total; + total_sum += accuracy.length_constraints_total; + total_sum += accuracy.punctuation_total; + total_sum += accuracy.startend_total; + + return total_sum > 0 ? static_cast(correct_sum) / total_sum : 0.0f; +} + +std::string IFEval::ComputeAccuracyString() { + float acc = ComputeAccuracy(); + return "Accuracy: " + std::to_string(acc * 100.0f) + "%"; +} + +inline std::vector> +IFEval::BuildInstructions(const tensorflow::Example& ex) { + std::vector> out; + + // ---- helpers (local) ---- + auto parse_relation = [](const std::string& s) -> ifeval::Relation { + return (s == "at least") ? ifeval::Relation::AT_LEAST + : ifeval::Relation::LESS_THAN; + }; + + auto add = [&](auto ptr) { out.emplace_back(std::move(ptr)); }; + + auto get_strs = [&](const std::string& key, + std::vector* vals) -> bool { + const auto& sfield = tensorflow::GetFeatureValues(key, ex); + std::vector svals(sfield.begin(), sfield.end()); + *vals = std::move(svals); + return true; + }; + auto get_ints = [&](const std::string& key, + std::vector* vals) -> bool { + const auto& ifield = tensorflow::GetFeatureValues(key, ex); + std::vector ivals(ifield.begin(), ifield.end()); + *vals = std::move(ivals); + return true; + }; + auto get_str = [&](const std::string& key, std::string* val) -> bool { + std::vector tmp; + if (!get_strs(key, &tmp) || tmp.empty()) return false; + *val = std::move(tmp[0]); + return true; + }; + auto get_int = [&](const std::string& key, int* val) -> bool { + std::vector tmp; + if (!get_ints(key, &tmp) || tmp.empty()) return false; + *val = static_cast(tmp[0]); + return true; + }; + + // Read instruction_id_list (bytes_list of strings) without touching + // ex.features().feature() + const auto& id_field = + tensorflow::GetFeatureValues("instruction_id_list", ex); + std::vector ids(id_field.begin(), id_field.end()); + if (ids.empty()) return out; + + // Enum for switch (one case per instruction kind) + enum class Kind { + kCapitalWordFrequency, + kEnglishCapital, + kEnglishLowercase, + kRepeatPrompt, + kTwoResponses, + kNumberPlaceholders, + kPostscript, + kConstrainedResponse, + kJsonFormat, + kMultipleSections, + kNumberBulletLists, + kNumberHighlightedSections, + kTitle, + kExistence, + kForbiddenWords, + kFrequency, + kLetterFrequency, + kResponseLanguage, + kNthParagraphFirstWord, + kNumberParagraphs, + kNumberSentences, + kNumberWords, + kNoComma, + kEndChecker, + kQuotation, + kUnknown + }; + + auto to_kind = [](const std::string& id) -> Kind { + auto colon = id.find(':'); + std::string name = (colon == std::string::npos) ? id : id.substr(colon + 1); + if (name == "capital_word_frequency") return Kind::kCapitalWordFrequency; + if (name == "english_capital") return Kind::kEnglishCapital; + if (name == "english_lowercase") return Kind::kEnglishLowercase; + if (name == "repeat_prompt") return Kind::kRepeatPrompt; + if (name == "two_responses") return Kind::kTwoResponses; + if (name == "number_placeholders") return Kind::kNumberPlaceholders; + if (name == "postscript") return Kind::kPostscript; + if (name == "constrained_response") return Kind::kConstrainedResponse; + if (name == "json_format") return Kind::kJsonFormat; + if (name == "multiple_sections") return Kind::kMultipleSections; + if (name == "number_bullet_lists") return Kind::kNumberBulletLists; + if (name == "number_highlighted_sections") + return Kind::kNumberHighlightedSections; + if (name == "title") return Kind::kTitle; + if (name == "existence") return Kind::kExistence; + if (name == "forbidden_words") return Kind::kForbiddenWords; + if (name == "frequency") return Kind::kFrequency; + if (name == "letter_frequency") return Kind::kLetterFrequency; + if (name == "response_language") return Kind::kResponseLanguage; + if (name == "nth_paragraph_first_word") return Kind::kNthParagraphFirstWord; + if (name == "number_paragraphs") return Kind::kNumberParagraphs; + if (name == "number_sentences") return Kind::kNumberSentences; + if (name == "number_words") return Kind::kNumberWords; + if (name == "no_comma") return Kind::kNoComma; + if (name == "end_checker") return Kind::kEndChecker; + if (name == "quotation") return Kind::kQuotation; + return Kind::kUnknown; + }; + + // Build each instruction from kwargs//* using + // tensorflow::GetFeatureValues(ex, key, &vec) + for (int i = 0; i < static_cast(ids.size()); ++i) { + const std::string& id = ids[i]; + const Kind kind = to_kind(id); + + auto K = [&](const std::string& key) { + return "kwargs/" + std::to_string(i) + "/" + key; + }; + + switch (kind) { + case Kind::kCapitalWordFrequency: { + int pct = 0; + std::string rel; + get_int(K("capital_frequency"), &pct); + get_str(K("capital_relation"), &rel); + add(std::make_unique( + pct, parse_relation(rel))); + break; + } + case Kind::kEnglishCapital: { + add(std::make_unique()); + break; + } + case Kind::kEnglishLowercase: { + add(std::make_unique()); + break; + } + case Kind::kRepeatPrompt: { + std::string p; + get_str(K("prompt_to_repeat"), &p); + add(std::make_unique(p)); + break; + } + case Kind::kTwoResponses: { + add(std::make_unique()); + break; + } + case Kind::kNumberPlaceholders: { + int n = 0; + get_int(K("num_placeholders"), &n); + add(std::make_unique(n)); + break; + } + case Kind::kPostscript: { + std::string m; + get_str(K("postscript_marker"), &m); + add(std::make_unique(m)); + break; + } + case Kind::kConstrainedResponse: { + add(std::make_unique()); + break; + } + case Kind::kJsonFormat: { + add(std::make_unique()); + break; + } + case Kind::kMultipleSections: { + int n = 0; + std::string sep; + get_int(K("num_sections"), &n); + get_str(K("section_spliter"), &sep); + add(std::make_unique(n, sep)); + break; + } + case Kind::kNumberBulletLists: { + int n = 0; + get_int(K("num_bullets"), &n); + add(std::make_unique(n)); + break; + } + case Kind::kNumberHighlightedSections: { + int n = 0; + get_int(K("num_highlights"), &n); + add(std::make_unique(n)); + break; + } + case Kind::kTitle: { + add(std::make_unique()); + break; + } + case Kind::kExistence: { + std::vector kws; + get_strs(K("keywords"), &kws); + add(std::make_unique(kws)); + break; + } + case Kind::kForbiddenWords: { + std::vector bad; + get_strs(K("forbidden_words"), &bad); + add(std::make_unique(bad)); + break; + } + case Kind::kFrequency: { + int n = 0; + std::string kw, rel; + get_int(K("frequency"), &n); + get_str(K("keyword"), &kw); + get_str(K("relation"), &rel); + add(std::make_unique(n, kw, parse_relation(rel))); + break; + } + case Kind::kLetterFrequency: { + int n = 0; + std::string letter, rel; + get_int(K("let_frequency"), &n); + get_str(K("letter"), &letter); + get_str(K("let_relation"), &rel); + char ch = letter.empty() ? 'a' : letter[0]; + add(std::make_unique(n, ch, + parse_relation(rel))); + break; + } + case Kind::kResponseLanguage: { + std::string lang; + get_str(K("language"), &lang); + add(std::make_unique(lang)); + break; + } + case Kind::kNthParagraphFirstWord: { + int nth = 0, total = 0; + std::string fw; + get_int(K("nth_paragraph"), &nth); + get_int(K("num_paragraphs"), &total); + get_str(K("first_word"), &fw); + add(std::make_unique(nth, fw, total)); + break; + } + case Kind::kNumberParagraphs: { + int n = 0; + get_int(K("num_paragraphs"), &n); + add(std::make_unique(n)); + break; + } + case Kind::kNumberSentences: { + int n = 0; + std::string rel; + get_int(K("num_sentences"), &n); + get_str(K("relation"), &rel); + add(std::make_unique(n, parse_relation(rel))); + break; + } + case Kind::kNumberWords: { + int n = 0; + std::string rel; + get_int(K("num_words"), &n); + get_str(K("relation"), &rel); + add(std::make_unique(n, parse_relation(rel))); + break; + } + case Kind::kNoComma: { + add(std::make_unique()); + break; + } + case Kind::kEndChecker: { + std::string end; + get_str(K("end_phrase"), &end); + add(std::make_unique(end)); + break; + } + case Kind::kQuotation: { + add(std::make_unique()); + break; + } + case Kind::kUnknown: + default: { + // Unknown instruction id: skip (or handle as needed) + break; + } + } + } + + return out; +} + +inline void IFEval::ProcessResult(ifeval::InstructionGroup group, + bool is_correct) { + uint8_t correct_value = is_correct ? 1 : 0; + switch (group) { + case ifeval::InstructionGroup::CHANGE_CASE: + accuracy.change_case_correct += correct_value; + accuracy.change_case_total++; + break; + + case ifeval::InstructionGroup::COMBINATION: + accuracy.combination_correct += correct_value; + accuracy.combination_total++; + break; + + case ifeval::InstructionGroup::DETECTABLE_CONTENT: + accuracy.detectable_content_correct += correct_value; + accuracy.detectable_content_total++; + break; + + case ifeval::InstructionGroup::DETECTABLE_FORMAT: + accuracy.detectable_format_correct += correct_value; + accuracy.detectable_format_total++; + break; + + case ifeval::InstructionGroup::KEYWORDS: + accuracy.keywords_correct += correct_value; + accuracy.keywords_total++; + break; + + case ifeval::InstructionGroup::LANGUAGE: + accuracy.language_correct += correct_value; + accuracy.language_total++; + break; + + case ifeval::InstructionGroup::LENGTH_CONSTRAINTS: + accuracy.length_constraints_correct += correct_value; + accuracy.length_constraints_total++; + break; + + case ifeval::InstructionGroup::PUNCTUATION: + accuracy.punctuation_correct += correct_value; + accuracy.punctuation_total++; + break; + + case ifeval::InstructionGroup::STARTEND: + accuracy.startend_correct += correct_value; + accuracy.startend_total++; + break; + + default: + break; + } +} + +} // namespace mobile +} // namespace mlperf diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h new file mode 100644 index 000000000..80866aad5 --- /dev/null +++ b/flutter/cpp/datasets/ifeval.h @@ -0,0 +1,90 @@ +#ifndef MLPERF_DATASETS_IFEVAL_H_ +#define MLPERF_DATASETS_IFEVAL_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "flutter/cpp/dataset.h" +#include "flutter/cpp/datasets/ifeval_utils/types.h" +#include "flutter/cpp/datasets/squad_utils/tfrecord_reader.h" +#include "src/sentencepiece_processor.h" +#include "tensorflow/core/example/example.pb.h" + +namespace mlperf { +namespace mobile { +namespace ifeval { +struct GroupAccuracy { + size_t change_case_correct = 0, combination_correct = 0, + detectable_content_correct = 0, detectable_format_correct = 0, + keywords_correct = 0, language_correct = 0, + length_constraints_correct = 0, punctuation_correct = 0, + startend_correct = 0; + size_t change_case_total = 0, combination_total = 0, + detectable_content_total = 0, detectable_format_total = 0, + keywords_total = 0, language_total = 0, length_constraints_total = 0, + punctuation_total = 0, startend_total = 0; +}; +} // namespace ifeval +class IFEval : public Dataset { + public: + IFEval(Backend* backend, const std::string& input_tfrecord, + const std::string& sp_path, bool loose_follow); + + const std::string& Name() override { return name_; } + + size_t TotalSampleCount() override { return samples_.size(); } + + size_t PerformanceSampleCount() override { return 1; } + + void LoadSamplesToRam(const std::vector& samples) override; + + void UnloadSamplesFromRam( + const std::vector& samples) override; + + std::vector GetData(int sample_idx) override; + + std::vector ProcessOutput( + const int sample_idx, const std::vector& outputs) override; + + int64_t GetOutputTokenCount(const int sample_idx) override; + + bool HasAccuracy() override; + + float ComputeAccuracy() override; + + std::string ComputeAccuracyString() override; + + inline std::vector> BuildInstructions( + const tensorflow::Example& ex); + + inline void ProcessResult(ifeval::InstructionGroup group, bool is_correct); + + private: + const std::string name_ = "IFEval"; + + TFRecordReader sample_reader_; + + std::vector> samples_; + std::vector sample_output_token_counts_; + std::set loaded_sample_ids_; + std::unique_ptr sp_processor; + + ifeval::GroupAccuracy accuracy; + bool loose_follow_; + + std::string start_token = ""; + std::string end_token = ""; + int start_token_id; + int end_token_id; +}; + +} // namespace mobile +} // namespace mlperf + +#endif // MLPERF_DATASETS_IFEVAL_H_ diff --git a/flutter/cpp/datasets/ifeval_utils/BUILD b/flutter/cpp/datasets/ifeval_utils/BUILD new file mode 100644 index 000000000..dde8f8a72 --- /dev/null +++ b/flutter/cpp/datasets/ifeval_utils/BUILD @@ -0,0 +1,38 @@ +# Copyright 2025 The MLPerf Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "ifeval_utils", + hdrs = [ + "common.h", + "types.h", + ], + copts = select({ + "//flutter/android/commonlibs:use_asan": [ + "-fsanitize=address", + "-g", + "-O1", + "-fno-omit-frame-pointer", + ], + "//conditions:default": [], + }), + deps = [ + ], +) diff --git a/flutter/cpp/datasets/ifeval_utils/common.h b/flutter/cpp/datasets/ifeval_utils/common.h new file mode 100644 index 000000000..f1bcdfe03 --- /dev/null +++ b/flutter/cpp/datasets/ifeval_utils/common.h @@ -0,0 +1,131 @@ +#ifndef MLPERF_DATASETS_IFEVAL_UTILS_COMMON_H_ +#define MLPERF_DATASETS_IFEVAL_UTILS_COMMON_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace mlperf { +namespace mobile { +namespace ifeval { + +inline std::string ltrim(std::string s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + })); + return s; +} +inline std::string rtrim(std::string s) { + s.erase(std::find_if(s.rbegin(), s.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + s.end()); + return s; +} +inline std::string trim(std::string s) { return rtrim(ltrim(std::move(s))); } + +inline std::string tolower(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), + [](unsigned char c) { return std::tolower(c); }); + return s; +} + +inline bool ends_with(const std::string& s, const std::string& suf) { + if (s.size() < suf.size()) return false; + std::string a = tolower(s.substr(s.size() - suf.size())); + std::string b = tolower(suf); + return a == b; +} + +inline bool contains_string(const std::string& text, + const std::string& substring) { + std::string h = tolower(text), n = tolower(substring); + return h.find(n) != std::string::npos; +} + +inline bool contains_word(const std::string& text, const std::string& word) { + std::regex rx("\\b" + word + "\\b", std::regex::icase); + return std::regex_search(text.begin(), text.end(), rx); +} + +inline bool contains_none(const std::string& text, + const std::vector& words) { + for (const auto& w : words) + if (contains_word(text, w)) return false; + return true; +} + +inline std::string remove_font_modifiers(const std::string& s) { + std::string out; + out.reserve(s.size()); + + // bool inBacktick = false; + for (std::size_t i = 0; i < s.size(); ++i) { + char c = s[i]; + + // toggle backtick code span + if (c == '`') { + // inBacktick = !inBacktick; + continue; // drop the backtick itself + } + + // skip emphasis/strong/strike/escape chars as long as they're not preceeded + // by an escape character + if ((c == '*' || c == '_' || c == '~' || c == '\\') && s[i - 1] != '\\') + continue; + + // remove heading markers (#) at line starts + if ((c == '#') && (i == 0 || s[i - 1] == '\n')) continue; + + // drop leading '>' in blockquotes + if ((c == '>') && (i == 0 || s[i - 1] == '\n')) continue; + + out.push_back(c); + } + return out; +} + +inline std::string remove_first_line(const std::string& s) { + std::size_t pos = s.find('\n'); + return (pos == std::string::npos) ? std::string{} : s.substr(pos + 1); + // If there is no newline, removing the first line yields empty. +} + +inline std::string remove_last_line(const std::string& s) { + std::size_t pos = s.rfind('\n'); + return (pos == std::string::npos) ? std::string{} : s.substr(0, pos); + // If there is no newline, removing the last line yields empty. +} + +// Returns the 8 transformations as an array of strings. +// Index is a bitmask over {font_mod (bit0), remove_first (bit1), remove_last +// (bit2)}. + +// 000 (0) nothing +// 001 (1) font +// 010 (2) fl +// 011 (3) font & fl +// 100 (4) ll +// 101 (5) ll & font +// 110 (6) fl & ll +// 111 (7) all +inline std::array transform_response(const std::string& resp) { + std::array out{}; + for (int mask = 0; mask < 8; ++mask) { + std::string t = resp; + if (mask & 0b001) t = remove_font_modifiers(t); + if (mask & 0b010) t = remove_first_line(t); + if (mask & 0b100) t = remove_last_line(t); + out[mask] = std::move(t); + } + return out; +} + +} // namespace ifeval +} // namespace mobile +} // namespace mlperf +#endif // MLPERF_DATASETS_IFEVAL_UTILS_COMMON_H_ diff --git a/flutter/cpp/datasets/ifeval_utils/generate_tfrecords.py b/flutter/cpp/datasets/ifeval_utils/generate_tfrecords.py new file mode 100644 index 000000000..8e7dd25ab --- /dev/null +++ b/flutter/cpp/datasets/ifeval_utils/generate_tfrecords.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 + +import argparse +import json +import sys +from typing import Any, Dict, Iterable, List, Tuple + +import tensorflow as tf + + +def parse_args(): + p = argparse.ArgumentParser( + description="Convert a JSONL of IFEval prompts to TFRecord (Example) with key, prompt, instruction list, and kwargs/* features." + ) + p.add_argument("--input_file", type=str, required=True, help="Path to input JSONL file.") + p.add_argument("--output_file", type=str, required=True, help="Path to output TFRecord file.") + return p.parse_args() + + +def iter_jsonl(path: str) -> Iterable[Tuple[int, Dict[str, Any]]]: + with open(path, "r", encoding="utf-8") as f: + for lineno, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + yield lineno, json.loads(line) + except json.JSONDecodeError as e: + print(f"[warn] {path}:{lineno}: JSON decode error: {e}", file=sys.stderr) + + +def _bytes_feature(values: List[bytes]) -> tf.train.Feature: + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + + +def _int64_feature(values: List[int]) -> tf.train.Feature: + return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) + + +def _to_bytes(s: str) -> bytes: + return s.encode("utf-8") + + +def to_feature(value: Any) -> tf.train.Feature: + """ + Map a python value into a tf.train.Feature following supported types. + """ + if isinstance(value, bool): # check before int + return _int64_feature([int(value)]) + if isinstance(value, int): + return _int64_feature([value]) + if isinstance(value, str): + return _bytes_feature([_to_bytes(value)]) + + if isinstance(value, list): + if all(isinstance(x, str) for x in value): + return _bytes_feature([_to_bytes(x) for x in value]) + if all(isinstance(x, bool) for x in value): + return _int64_feature([int(x) for x in value]) + if all(isinstance(x, int) for x in value): + return _int64_feature([int(x) for x in value]) + + # Fallback: JSON-serialize and store as bytes + try: + s = json.dumps(value, ensure_ascii=False) + except Exception: + s = str(value) + return _bytes_feature([_to_bytes(s)]) + + +def build_example(record: Dict[str, Any]) -> tf.train.Example: + """ + Build a tf.train.Example with: + - "key": int64 (identifier) + - "prompt": bytes (UTF-8) + - "instruction_id_list": bytes_list of strings + - "kwargs//": features from kwargs[i][k] + """ + feats: Dict[str, tf.train.Feature] = {} + + # key (identifier) — ensure we write feature named "key" + key_val = record.get("key", None) + if isinstance(key_val, (int, bool)): + feats["key"] = _int64_feature([int(key_val)]) + else: + try: + feats["key"] = _int64_feature([int(key_val)]) + except Exception: + print(f"[warn] record has non-int key={key_val!r}; writing 0", file=sys.stderr) + feats["key"] = _int64_feature([0]) + + # prompt (optional but recommended) + prompt = record.get("prompt", "") + if not isinstance(prompt, str): + prompt = str(prompt) if prompt is not None else "" + feats["prompt"] = _bytes_feature([_to_bytes(prompt)]) + + # instruction_id_list + ids = record.get("instruction_id_list", []) + if not isinstance(ids, list) or not all(isinstance(x, str) for x in ids): + raise ValueError("Each record must contain 'instruction_id_list' as a list of strings.") + feats["instruction_id_list"] = _bytes_feature([_to_bytes(x) for x in ids]) + + # kwargs aligned by index + kwargs_list = record.get("kwargs", []) + if not isinstance(kwargs_list, list): + kwargs_list = [] + + for i, _id in enumerate(ids): + if i >= len(kwargs_list): + continue + entry = kwargs_list[i] + if not isinstance(entry, dict) or not entry: + continue + for k, v in entry.items(): + fname = f"kwargs/{i}/{k}" + feats[fname] = to_feature(v) + + return tf.train.Example(features=tf.train.Features(feature=feats)) + + +def main(): + args = parse_args() + options = tf.io.TFRecordOptions(compression_type="ZLIB") + + written = 0 + with tf.io.TFRecordWriter(args.output_file, options=options) as w: + for lineno, rec in iter_jsonl(args.input_file): + try: + ex = build_example(rec) + w.write(ex.SerializeToString()) + written += 1 + except Exception as e: + print(f"[warn] skipping line {lineno}: {e}", file=sys.stderr) + + print(f"[done] wrote {written} Example(s) -> {args.output_file}", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/flutter/cpp/datasets/ifeval_utils/types.h b/flutter/cpp/datasets/ifeval_utils/types.h new file mode 100644 index 000000000..cffce43ea --- /dev/null +++ b/flutter/cpp/datasets/ifeval_utils/types.h @@ -0,0 +1,683 @@ +#ifndef MLPERF_DATASETS_IFEVAL_UTILS_TYPES_H_ +#define MLPERF_DATASETS_IFEVAL_UTILS_TYPES_H_ + +#include +#include +#include +#include +#include +#include + +#include "flutter/cpp/datasets/ifeval_utils/common.h" + +namespace mlperf { +namespace mobile { +namespace ifeval { + +enum InstructionGroup { + CHANGE_CASE, + COMBINATION, + DETECTABLE_CONTENT, + DETECTABLE_FORMAT, + KEYWORDS, + LANGUAGE, + LENGTH_CONSTRAINTS, + PUNCTUATION, + STARTEND +}; + +enum Relation { AT_LEAST, LESS_THAN }; + +inline bool compare(size_t value, size_t threshold, Relation rel) { + if (rel == AT_LEAST) return value >= threshold; + return value < threshold; // LESS_THAN +} + +class Instruction { + public: + virtual ~Instruction() = default; + virtual constexpr InstructionGroup Group() = 0; + + bool IsFollowed(const std::string& resp, bool loose = false) const { + // For strict checks, just verify the response itself + if (!loose) return verify_(resp); + + auto transformations = transform_response(resp); + for (std::string transformation : transformations) { + if (verify_(resp)) return true; + } + return false; + } + + private: + virtual bool verify_(const std::string& resp) const = 0; +}; + +/* ---------- CHANGE_CASE ---------- */ + +class CapitalWordFrequency : public Instruction { + public: + CapitalWordFrequency(int capital_frequency, Relation capital_relation) + : threshold_(capital_frequency), rel_(capital_relation) {} + + constexpr InstructionGroup Group() override { return CHANGE_CASE; } + + private: + int threshold_; + Relation rel_; + + static size_t CapitalWords(const std::string& resp) { + size_t words = 0; + std::istringstream is(resp); + std::string w; + while (is >> w) { + size_t i = 0; + while (i < w.size() && !std::isalnum((unsigned char)w[i]) && + !std::isupper((unsigned char)w[i])) + ++i; + if (i >= w.size()) continue; + ++words; + } + return words; + } + + bool verify_(const std::string& resp) const override { + size_t words = CapitalWords(resp); + return compare(words, threshold_, rel_); + } +}; + +class EnglishCapital : public Instruction { + public: + EnglishCapital() = default; + constexpr InstructionGroup Group() override { return CHANGE_CASE; } + + private: + bool verify_(const std::string& resp) const override { + return std::all_of(resp.begin(), resp.end(), + [](unsigned char c) { return std::isupper(c); }); + } +}; + +class EnglishLowercase : public Instruction { + public: + EnglishLowercase() = default; + constexpr InstructionGroup Group() override { return CHANGE_CASE; } + + private: + bool verify_(const std::string& resp) const override { + return std::all_of(resp.begin(), resp.end(), + [](unsigned char c) { return std::islower(c); }); + } +}; + +/* ---------- COMBINATION ---------- */ + +class RepeatPrompt : public Instruction { + public: + explicit RepeatPrompt(std::string prompt_to_repeat) + : prompt_(std::move(prompt_to_repeat)) {} + constexpr InstructionGroup Group() override { return COMBINATION; } + + private: + std::string prompt_; + bool verify_(const std::string& resp) const override { + // TODO replace with startswith? + return contains_string(resp, prompt_); + } +}; + +class TwoResponses : public Instruction { + public: + TwoResponses() = default; + constexpr InstructionGroup Group() override { return COMBINATION; } + + private: + bool verify_(const std::string& resp) const override { + std::size_t count = 0; + std::size_t pos = resp.find("******"); + while (pos != std::string::npos) { + if (++count > 1) return false; // more than one occurrence + pos = resp.find("******", pos + 6); // disallow overlapping matches + } + return count > 0; + } +}; + +/* ------- DETECTABLE_CONTENT ------- */ + +class NumberPlaceholders : public Instruction { + public: + explicit NumberPlaceholders(int num_placeholders) : n_(num_placeholders) {} + constexpr InstructionGroup Group() override { return DETECTABLE_CONTENT; } + + private: + int n_; + bool verify_(const std::string& resp) const override { + std::size_t count = 0, pos = 0; + while (pos < resp.length() && + (int)count < n_) { // no need to keep looking if the requirement is + // already satisfied + std::size_t open = resp.find('[', pos); + if (open == std::string::npos) break; + std::size_t close = resp.find(']', open + 1); + if (close == std::string::npos) break; + + if (close > open + 1) { // non-empty inner + const std::string inner = resp.substr(open + 1, close - open - 1); + bool ok = true; + for (unsigned char ch : inner) { + if (std::isspace(ch) || !(std::isalnum(ch) || ch == '_')) { + ok = false; + break; + } + } + if (ok) ++count; + } + pos = close + 1; // continue after this closing bracket + } + return (int)count >= n_; + } +}; + +class Postscript : public Instruction { + public: + explicit Postscript(std::string postscript_marker) + : marker_(std::move(postscript_marker)) {} + constexpr InstructionGroup Group() override { return DETECTABLE_CONTENT; } + + private: + std::string marker_; + bool verify_(const std::string& resp) const override { + return contains_string(resp, marker_); + } +}; + +/* ------- DETECTABLE_FORMAT -------- */ + +class ConstrainedResponse : public Instruction { + public: + ConstrainedResponse() = default; + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + bool verify_(const std::string& resp) const override { + return resp == "My answer is yes." || resp == "My answer is no." || + resp == "My answer is maybe."; + } +}; + +class JsonFormat : public Instruction { + public: + JsonFormat() = default; + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + // TODO possibly use a C++ json validator instead + bool verify_(const std::string& resp) const override { + std::string t = resp; + if (t.empty()) return false; + if (!((t.front() == '{' && t.back() == '}') || + (t.front() == '[' && t.back() == ']'))) + return false; + int brace = 0, bracket = 0; + bool in_str = false, esc = false; + for (char c : t) { + if (esc) { + esc = false; + continue; + } + if (c == '\\') { + esc = true; + continue; + } + if (c == '"') { + in_str = !in_str; + continue; + } + if (in_str) continue; + if (c == '{') + ++brace; + else if (c == '}') + --brace; + else if (c == '[') + ++bracket; + else if (c == ']') + --bracket; + if (brace < 0 || bracket < 0) return false; + } + return brace == 0 && bracket == 0 && !in_str; + } +}; + +class MultipleSections : public Instruction { + public: + MultipleSections(int num_sections, std::string section_spliter) + : n_(num_sections), sep_(std::move(section_spliter)) {} + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + int n_; + std::string sep_; + static int CountNonEmpty(const std::vector& v) { + int c = 0; + for (auto& p : v) + if (!trim(p).empty()) ++c; + return c; + } + inline std::vector SplitByDelim(const std::string& s, + const std::string& delim) const { + if (delim.empty()) return {s}; + std::vector parts; + size_t start = 0; + while (true) { + size_t pos = s.find(delim, start); + if (pos == std::string::npos) { + parts.push_back(s.substr(start)); + break; + } + parts.push_back(s.substr(start, pos - start)); + start = pos + delim.size(); + } + return parts; + } + bool verify_(const std::string& resp) const override { + auto parts = SplitByDelim(resp, sep_); + return CountNonEmpty(parts) == n_; + } +}; + +class NumberBulletLists : public Instruction { + public: + explicit NumberBulletLists(int num_bullets) : n_(num_bullets) {} + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + int n_; + + inline std::vector SplitLines(const std::string& s) const { + std::vector out; + std::string cur; + std::istringstream is(s); + while (std::getline(is, cur)) out.push_back(cur); + return out; + } + + bool verify_(const std::string& resp) const override { + size_t count = 0; + for (const auto& line : SplitLines(resp)) { + std::string t = trim(line); + if (t.rfind("* ", 0) == 0) { + ++count; + continue; + } + } + return (int)count == n_; + } +}; + +class NumberHighlightedSections : public Instruction { + public: + explicit NumberHighlightedSections(int num_highlights) : n_(num_highlights) {} + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + int n_; + bool verify_(const std::string& resp) const override { + std::size_t count = 0; + std::size_t pos = 0; + + while (true) { + // find opening '*' + std::size_t open = resp.find('*', pos); + if (open == std::string::npos) break; + + // need at least one non-*\r\n char after the opener + if (open + 1 >= resp.size()) break; + char next = resp[open + 1]; + if (next == '*' || next == '\n' || next == '\r') { + pos = open + 1; // not a valid start; try from the next '*' + continue; + } + + // find the first '*' or newline after the opener + std::size_t stop = resp.find_first_of("*\r\n", open + 1); + if (stop == std::string::npos) break; + + if (resp[stop] == '*') { + // we have "*...*" with no '*' or newline inside + ++count; + pos = stop + 1; // continue after the closing '*' + } else { + // newline encountered before a closing '*': this opener can't match + pos = stop + 1; // continue scanning after the newline + } + } + return (int)count >= n_; + } +}; + +class Title : public Instruction { + public: + Title() = default; + constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } + + private: + bool verify_(const std::string& resp) const override { + std::size_t pos_open = resp.find("<<"); + // TODO should an empty title be allowed? + return (pos_open != std::string::npos) && + (resp.find(">>", pos_open + 2) != + std::string::npos); // found "<<" with a following ">>" + } +}; + +/* -------------- KEYWORDS -------------- */ + +class Existence : public Instruction { + public: + explicit Existence(std::vector keywords) + : kws_(std::move(keywords)) {} + constexpr InstructionGroup Group() override { return KEYWORDS; } + + private: + std::vector kws_; + bool verify_(const std::string& resp) const override { + for (const auto& k : kws_) + if (!contains_word(resp, k)) return false; + return true; + } +}; + +class ForbiddenWords : public Instruction { + public: + explicit ForbiddenWords(std::vector forbidden_words) + : bad_(std::move(forbidden_words)) {} + constexpr InstructionGroup Group() override { return KEYWORDS; } + + private: + std::vector bad_; + bool verify_(const std::string& resp) const override { + return contains_none(resp, bad_); + } +}; + +class Frequency : public Instruction { + public: + Frequency(int frequency, std::string keyword, Relation relation) + : n_(frequency), kw_(std::move(keyword)), rel_(relation) {} + constexpr InstructionGroup Group() override { return KEYWORDS; } + + private: + int n_; + std::string kw_; + Relation rel_; + bool verify_(const std::string& resp) const override { + std::regex rx("\\b" + kw_ + "\\b", std::regex::icase); + size_t count = 0; + auto it = std::sregex_iterator(resp.begin(), resp.end(), rx); + auto end = std::sregex_iterator(); + for (; it != end; ++it) ++count; + return compare(count, (size_t)n_, rel_); + } +}; + +class LetterFrequency : public Instruction { + public: + LetterFrequency(int let_frequency, char letter, Relation let_relation) + : n_(let_frequency), letter_(letter), rel_(let_relation) {} + constexpr InstructionGroup Group() override { return KEYWORDS; } + + private: + int n_; + char letter_; + Relation rel_; + static size_t CountLetterICase(const std::string& s, char letter) { + size_t c = 0; + char lower = std::tolower((unsigned char)letter); + for (unsigned char ch : s) + if (std::tolower(ch) == lower) ++c; + return c; + } + bool verify_(const std::string& resp) const override { + size_t c = CountLetterICase(resp, letter_); + return compare(c, (size_t)n_, rel_); + } +}; + +/* -------------- LANGUAGE -------------- */ + +class ResponseLanguage : public Instruction { + public: + explicit ResponseLanguage(std::string language) + : lang_(std::move(language)) {} + constexpr InstructionGroup Group() override { return LANGUAGE; } + + private: + std::string lang_; + + inline bool LanguageHeuristic(const std::string& text, + const std::string& lang) const { + std::string L = tolower(lang); + const std::string& t = text; + + auto non_ascii_ratio = [&]() { + size_t non_ascii = 0, total = 0; + for (unsigned char c : t) { + if (std::isalpha(c)) { + ++total; + if (c >= 128) ++non_ascii; + } + } + return total == 0 ? 0.0 : (double)non_ascii / (double)total; + }; + + if (L == "en") { + return non_ascii_ratio() < 0.05; + } + if (L == "tr") { + return t.find("ÄŸ") != std::string::npos || + t.find("Äž") != std::string::npos || + t.find("ÅŸ") != std::string::npos || + t.find("Åž") != std::string::npos || + t.find("ı") != std::string::npos || + t.find("İ") != std::string::npos || + t.find("ö") != std::string::npos || + t.find("Ö") != std::string::npos || + t.find("ç") != std::string::npos || + t.find("Ç") != std::string::npos || + t.find("ü") != std::string::npos || + t.find("Ü") != std::string::npos; + } + if (L == "es") { + return t.find("ñ") != std::string::npos || + t.find("Ñ") != std::string::npos || + t.find("á") != std::string::npos || + t.find("é") != std::string::npos || + t.find("í") != std::string::npos || + t.find("ó") != std::string::npos || + t.find("ú") != std::string::npos; + } + if (L == "fr") { + return t.find("é") != std::string::npos || + t.find("è") != std::string::npos || + t.find("ê") != std::string::npos || + t.find("ç") != std::string::npos || + t.find("à") != std::string::npos; + } + if (L == "de") { + return t.find("ä") != std::string::npos || + t.find("ö") != std::string::npos || + t.find("ü") != std::string::npos || + t.find("ß") != std::string::npos; + } + return non_ascii_ratio() > 0.05; + } + + bool verify_(const std::string& resp) const override { + return LanguageHeuristic(resp, lang_); + } +}; + +/* ----------- LENGTH_CONSTRAINTS ----------- */ + +class NthParagraphFirstWord : public Instruction { + public: + NthParagraphFirstWord(int nth_paragraph, std::string first_word, + int num_paragraphs) + : nth_(nth_paragraph), + first_(std::move(first_word)), + total_(num_paragraphs) {} + constexpr InstructionGroup Group() override { return LENGTH_CONSTRAINTS; } + + private: + int nth_; + std::string first_; + int total_; + + static std::string FirstWord(const std::string& s) { + std::istringstream is(s); + std::string w; + is >> w; + return tolower(w); + } + + static inline std::vector SplitParagraphs(const std::string& s) { + // paragraphs separated only by the literal delimiter "\n\n" + std::vector paras; + std::size_t start = 0; + while (true) { + std::size_t pos = s.find("\n\n", start); + if (pos == std::string::npos) { + std::string chunk = s.substr(start); + if (!chunk.empty()) paras.push_back(rtrim(chunk)); + break; + } + std::string chunk = s.substr(start, pos - start); + if (!chunk.empty()) paras.push_back(rtrim(chunk)); + start = pos + 2; // skip the delimiter + } + return paras; + } + + bool verify_(const std::string& resp) const override { + auto paras = SplitParagraphs(resp); + if ((int)paras.size() != total_) return false; + if (nth_ <= 0 || nth_ > (int)paras.size()) return false; + auto target = trim(paras[nth_ - 1]); + if (target.empty()) return false; + return FirstWord(target) == tolower(first_); + } +}; + +class NumberParagraphs : public Instruction { + public: + explicit NumberParagraphs(int num_paragraphs) : n_(num_paragraphs) {} + constexpr InstructionGroup Group() override { return LENGTH_CONSTRAINTS; } + + private: + unsigned n_; + bool verify_(const std::string& resp) const override { + std::size_t count = 0, pos = 0; + while ((pos = resp.find("***", pos)) != std::string::npos) { + ++count; + pos += 3; // advance by 3 for non-overlapping matches + } + return count + 1 == n_; // since *** is a saparator, the actual count is 1 + // more than the number of separators + } +}; + +class NumberSentences : public Instruction { + public: + NumberSentences(int num_sentences, Relation relation) + : n_(num_sentences), rel_(relation) {} + constexpr InstructionGroup Group() override { return LENGTH_CONSTRAINTS; } + + private: + int n_; + Relation rel_; + bool verify_(const std::string& resp) const override { + size_t count = 0; + for (unsigned char c : resp) { + if (c == '.' || c == '!' || c == '?') ++count; + } + return compare(count, (size_t)n_, rel_); + } +}; + +class NumberWords : public Instruction { + public: + NumberWords(int num_words, Relation relation) + : n_(num_words), rel_(relation) {} + constexpr InstructionGroup Group() override { return LENGTH_CONSTRAINTS; } + + private: + int n_; + Relation rel_; + bool verify_(const std::string& resp) const override { + size_t count = 0; + bool in_word = false; + for (unsigned char c : resp) { + if (std::isalnum(c)) { + if (!in_word) { + in_word = true; + ++count; + } + } else + in_word = false; + } + return compare(count, (size_t)n_, rel_); + } +}; + +/* -------------- PUNCTUATION -------------- */ + +class NoComma : public Instruction { + public: + NoComma() = default; + constexpr InstructionGroup Group() override { return PUNCTUATION; } + + private: + bool verify_(const std::string& resp) const override { + return resp.find(',') == std::string::npos; + } +}; + +/* ---------------- STARTEND ---------------- */ + +class EndChecker : public Instruction { + public: + explicit EndChecker(std::string end_phrase) : end_(std::move(end_phrase)) {} + constexpr InstructionGroup Group() override { return STARTEND; } + + private: + std::string end_; + bool verify_(const std::string& resp) const override { + return ends_with(resp, end_); + } +}; + +class Quotation : public Instruction { + public: + Quotation() = default; + constexpr InstructionGroup Group() override { return STARTEND; } + + private: + bool verify_(const std::string& resp) const override { + if (resp.size() < 2) return false; + return resp.front() == '"' && resp.back() == '"'; + } +}; + +struct Sample { + int key; + std::string prompt; + std::vector input_tokens; + std::vector> instructions; +}; + +} // namespace ifeval +} // namespace mobile +} // namespace mlperf +#endif // MLPERF_DATASETS_IFEVAL_UTILS_TYPES_H_ diff --git a/flutter/cpp/proto/mlperf_task.proto b/flutter/cpp/proto/mlperf_task.proto index e87f47d1c..7da969bcd 100644 --- a/flutter/cpp/proto/mlperf_task.proto +++ b/flutter/cpp/proto/mlperf_task.proto @@ -69,7 +69,7 @@ message OneRunConfig { // Datasets for a task // -// Next ID: 8 +// Next ID: 9 message DatasetConfig { // Type of the dataset. enum DatasetType { @@ -81,6 +81,7 @@ message DatasetConfig { SNUSR = 5; COCOGEN = 6; MMLU = 7; + IFEVAL = 8; } required DatasetType type = 1; // Config of the dataset. diff --git a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt index a784857a5..5872cfd02 100644 --- a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt +++ b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt @@ -279,8 +279,27 @@ benchmark_setting { id: "pipeline" value: "LLMPipeline" } +} + +benchmark_setting { + benchmark_id: "llm_instruction" + framework: "TFLite" + delegate_choice: { + delegate_name: "CPU" + accelerator_name: "cpu" + accelerator_desc: "CPU" + model_file: { + model_path: "local:///mlperf_models/llama_q8_ekv3072.tflite" + model_checksum: "54efe0be372b55303673245067beef62" + } + model_file: { + model_path: "local:///mlperf_models/llama3_1b.spm.model" + model_checksum: "2ad260fc18b965ce16006d76c9327082" + } + } + delegate_selected: "CPU" custom_setting { - id: "sentencepiece_processor_path" - value: "llama3_1b.spm.model" + id: "pipeline" + value: "LLMPipeline" } } diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 5ca662d60..cb5696799 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -134,7 +134,7 @@ struct LLMBackendData { std::vector prompt_tokens; std::vector output_tokens; uint8_t threads = 30; - int max_output_tokens = 2; + int max_output_tokens = 2048; int stop_token_id = -1; LLMBackendData() {} From 97ba25f46eab74cca619e1153d1c6ab3c98a8ad4 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Sat, 11 Oct 2025 09:30:19 +0300 Subject: [PATCH 32/74] disabled XNNPACK AVX-VNNI for windows due to C2440 error --- .bazelrc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.bazelrc b/.bazelrc index 91a29048e..f086b5733 100644 --- a/.bazelrc +++ b/.bazelrc @@ -107,6 +107,10 @@ build:windows --host_linkopt=/OPT:REF build:windows --linkopt=/OPT:ICF build:windows --host_linkopt=/OPT:ICF +# MSVC does not support XNNPACK AVXVNNI instructions (causes C2440 error). +build:windows --define=xnn_enable_avxvnni=false +#build:windows --define=xnn_enable_avxvnniint8=false + # Address sanitizer build:asan --strip=never build:asan --copt -fsanitize=address From 5383e758d391d21fe5b8a1d09f90cac0796e2580 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 13 Oct 2025 07:20:23 +0300 Subject: [PATCH 33/74] moved accuracy calculation away from ProcessOutput, ifeval accuracy is calculated per instruction not per sample --- flutter/cpp/datasets/ifeval.cc | 48 +++++++++++++++----------------- flutter/cpp/datasets/ifeval.h | 10 +++++-- flutter/cpp/datasets/mmlu_gen.cc | 41 ++++++++++++++++----------- flutter/cpp/datasets/mmlu_gen.h | 10 +++---- 4 files changed, 59 insertions(+), 50 deletions(-) diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index ef31259b7..6c77ce443 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -44,7 +44,7 @@ IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, sample->instructions = std::move(instructions); samples_.push_back(std::move(sample)); - sample_output_token_counts_.push_back(0); + sample_output_tokens_.push_back(std::vector()); } } @@ -79,44 +79,39 @@ std::vector IFEval::ProcessOutput(const int sample_idx, const auto& output_tokens = *(reinterpret_cast*>(outputs[0])); - LOG(INFO) << '[' - << std::accumulate(std::next(output_tokens.begin()), - output_tokens.end(), - std::to_string(output_tokens[0]), - [](std::string a, int b) { - return std::move(a) + ", " + std::to_string(b); - }) - << "]\n"; + sample_output_tokens_[sample_idx] = output_tokens; - sample_output_token_counts_[sample_idx] = output_tokens.size(); + return {1}; +} + +int64_t IFEval::GetOutputTokenCount(const int sample_idx) { + return sample_output_tokens_[sample_idx].size(); +} +bool IFEval::HasAccuracy() { return true; } + +bool IFEval::ComputeSampleAccuracy(const int sample_idx, + ifeval::GroupAccuracy& accuracy) { std::string prediction; - sp_processor->Decode(output_tokens, &prediction).ok(); + sp_processor->Decode(sample_output_tokens_[sample_idx], &prediction).ok(); LOG(INFO) << "output(" << std::to_string(sample_idx) << "): " << prediction << std::endl; - bool is_correct = true; // Automatically pass samples with no instructions. - std::vector groups; for (const auto& instruction : samples_[sample_idx]->instructions) { - is_correct &= instruction->IsFollowed(prediction, loose_follow_); - groups.emplace_back(instruction->Group()); + bool is_correct = instruction->IsFollowed(prediction, loose_follow_); + ProcessResult(instruction->Group(), is_correct, accuracy); } - - for (auto group : groups) ProcessResult(group, is_correct); - - return {static_cast(is_correct)}; } -int64_t IFEval::GetOutputTokenCount(const int sample_idx) { - return sample_output_token_counts_[sample_idx]; -} - -bool IFEval::HasAccuracy() { return true; } - float IFEval::ComputeAccuracy() { uint16_t correct_sum; uint16_t total_sum; + ifeval::GroupAccuracy accuracy; + + for (auto sample_id : used_sample_ids_) { + ComputeSampleAccuracy(sample_id, accuracy); + } correct_sum += accuracy.change_case_correct; correct_sum += accuracy.combination_correct; @@ -431,7 +426,8 @@ IFEval::BuildInstructions(const tensorflow::Example& ex) { } inline void IFEval::ProcessResult(ifeval::InstructionGroup group, - bool is_correct) { + bool is_correct, + ifeval::GroupAccuracy& accuracy) { uint8_t correct_value = is_correct ? 1 : 0; switch (group) { case ifeval::InstructionGroup::CHANGE_CASE: diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index 80866aad5..27e1bd877 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -56,6 +56,9 @@ class IFEval : public Dataset { bool HasAccuracy() override; + bool ComputeSampleAccuracy(const int sample_idx, + ifeval::GroupAccuracy& accuracy); + float ComputeAccuracy() override; std::string ComputeAccuracyString() override; @@ -63,7 +66,8 @@ class IFEval : public Dataset { inline std::vector> BuildInstructions( const tensorflow::Example& ex); - inline void ProcessResult(ifeval::InstructionGroup group, bool is_correct); + inline void ProcessResult(ifeval::InstructionGroup group, bool is_correct, + ifeval::GroupAccuracy& accuracy); private: const std::string name_ = "IFEval"; @@ -71,11 +75,11 @@ class IFEval : public Dataset { TFRecordReader sample_reader_; std::vector> samples_; - std::vector sample_output_token_counts_; + std::vector> sample_output_tokens_; + std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; - ifeval::GroupAccuracy accuracy; bool loose_follow_; std::string start_token = ""; diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 67666ce97..b2b2a415a 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -10,7 +10,6 @@ namespace mlperf { namespace mobile { -// TODO add eos and bos tokens as config parameters MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, const std::string& sp_path, bool zero_shot) : sample_reader_(input_tfrecord), Dataset(backend) { @@ -48,7 +47,7 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, sample->answer = answer; samples_.push_back(std::move(sample)); - sample_output_token_counts_.push_back(0); + sample_output_tokens_.push_back(std::vector()); } } @@ -83,30 +82,40 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const auto& output_tokens = *(reinterpret_cast*>(outputs[0])); - sample_output_token_counts_[sample_idx] = output_tokens.size(); + sample_output_tokens_[sample_idx] = output_tokens; + used_sample_ids_.insert(sample_idx); + return {1}; +} + +int64_t MmluGen::GetOutputTokenCount(const int sample_idx) { + return sample_output_tokens_[sample_idx].size(); +} + +bool MmluGen::HasAccuracy() { return true; } + +bool MmluGen::ComputeSampleAccuracy(const int sample_idx) { std::string prediction; - sp_processor->Decode(output_tokens, &prediction).ok(); + sp_processor->Decode(sample_output_tokens_[sample_idx], &prediction).ok(); + + LOG(INFO) << "index: " << std::to_string(sample_idx) << std::endl; + LOG(INFO) << "Output: [[[" << prediction << "]]]" << std::endl; char predicted_char = find_answer_char(prediction); const std::string& correct = samples_[sample_idx]->answer; - bool is_correct = (predicted_char == correct[0]); - - total_++; - if (is_correct) correct_++; - - return {static_cast(is_correct)}; + return (predicted_char == correct[0]); } -int64_t MmluGen::GetOutputTokenCount(const int sample_idx) { - return sample_output_token_counts_[sample_idx]; -} +float MmluGen::ComputeAccuracy() { + int total(0), correct(0); -bool MmluGen::HasAccuracy() { return true; } + for (auto sample_id : used_sample_ids_) { + total++; + if (ComputeSampleAccuracy(sample_id)) correct++; + } -float MmluGen::ComputeAccuracy() { - return total_ > 0 ? static_cast(correct_) / total_ : 0.0f; + return total > 0 ? static_cast(correct) / total : 0.0f; } std::string MmluGen::ComputeAccuracyString() { diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 8c21caaaa..180d41a10 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include "flutter/cpp/dataset.h" @@ -42,6 +42,8 @@ class MmluGen : public Dataset { bool HasAccuracy() override; + bool ComputeSampleAccuracy(const int sample_idx); + float ComputeAccuracy() override; std::string ComputeAccuracyString() override; @@ -60,13 +62,11 @@ class MmluGen : public Dataset { }; std::vector> samples_; - std::vector sample_output_token_counts_; + std::vector> sample_output_tokens_; + std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; - size_t correct_ = 0; - size_t total_ = 0; - std::string start_token = ""; std::string end_token = ""; int start_token_id; From 8e21ed10779353052f194fc3829a985edf7f5f52 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 20 Oct 2025 02:02:27 +0300 Subject: [PATCH 34/74] fixed issue with app not finding model/tokenizer --- flutter/cpp/datasets/ifeval.cc | 9 ++-- flutter/cpp/datasets/ifeval.h | 5 --- flutter/cpp/datasets/mmlu_gen.cc | 23 +++++++--- flutter/cpp/datasets/mmlu_gen.h | 5 --- .../datasets/mmlu_utils/sentencepiece_utils.h | 43 ++++++++++++++++++- flutter/cpp/flutter/dart_run_benchmark.cc | 6 ++- flutter/lib/data/results/dataset_info.dart | 4 ++ .../tflite_settings_android.pbtxt | 25 +++-------- .../cpp/backend_tflite/llm_pipeline.cc | 32 ++++++++------ .../cpp/backend_tflite/llm_pipeline.h | 7 +-- 10 files changed, 98 insertions(+), 61 deletions(-) diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 6c77ce443..34124de2a 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -17,8 +17,6 @@ IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, Dataset(backend) { sp_processor = std::unique_ptr( LoadSentencePieceProcessor(sp_path)); - start_token_id = sp_processor->PieceToId(start_token); - end_token_id = sp_processor->PieceToId(end_token); // Load all TFRecord samples into memory // NOTE this can be moved to LoadSamplesToRam, but will cause delays between @@ -32,10 +30,10 @@ IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, tensorflow::GetFeatureValues("prompt", example).Get(0); auto instructions = BuildInstructions(example); - std::vector input_tokens; - sp_processor->Encode(prompt.c_str(), &input_tokens).ok(); - input_tokens.insert(input_tokens.begin(), start_token_id); + std::string input_formatted = FormatLlamaUserPrompt(prompt); + std::vector input_tokens; + sp_processor->Encode(input_formatted.c_str(), &input_tokens).ok(); auto sample = std::make_unique(); sample->key = key; @@ -67,7 +65,6 @@ std::vector IFEval::GetData(int sample_idx) { if (sample_idx < samples_.size()) { data.push_back(reinterpret_cast( const_cast*>(&(samples_[sample_idx]->input_tokens)))); - data.push_back(reinterpret_cast(const_cast(&end_token_id))); } return data; } diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index 27e1bd877..9d07fef27 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -81,11 +81,6 @@ class IFEval : public Dataset { std::unique_ptr sp_processor; bool loose_follow_; - - std::string start_token = ""; - std::string end_token = ""; - int start_token_id; - int end_token_id; }; } // namespace mobile diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index b2b2a415a..47c2d1c73 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -7,6 +7,9 @@ #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature_util.h" + +#define li LOG(INFO) << "li:" << __FILE__ << ":" << __LINE__ << "@" << __func__ + namespace mlperf { namespace mobile { @@ -15,8 +18,8 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, : sample_reader_(input_tfrecord), Dataset(backend) { sp_processor = std::unique_ptr( LoadSentencePieceProcessor(sp_path)); - start_token_id = sp_processor->PieceToId(start_token); - end_token_id = sp_processor->PieceToId(end_token); + + li; // Load all TFRecord samples into memory // NOTE this can be moved to LoadSamplesToRam, but will cause delays between @@ -35,12 +38,11 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, input.rfind("\n\n") + 2); // input-formatted shots are separated by 2 new lines - std::vector input_tokens; + std::string input_formatted = FormatLlamaUserPrompt(input, "Provide only the answer letter, do not provide any explanation or preface."); + std::vector input_tokens; sp_processor->Encode(input.c_str(), &input_tokens).ok(); - input_tokens.insert(input_tokens.begin(), start_token_id); - auto sample = std::make_unique(); sample->input = input; sample->input_tokens = input_tokens; @@ -49,6 +51,7 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, samples_.push_back(std::move(sample)); sample_output_tokens_.push_back(std::vector()); } + li; } void MmluGen::LoadSamplesToRam(const std::vector& samples) { @@ -67,10 +70,11 @@ void MmluGen::UnloadSamplesFromRam( std::vector MmluGen::GetData(int sample_idx) { std::vector data; + li; + LOG(INFO) << "Sample ID: " << std::to_string(sample_idx); if (sample_idx < samples_.size()) { data.push_back(reinterpret_cast( const_cast*>(&(samples_[sample_idx]->input_tokens)))); - data.push_back(reinterpret_cast(const_cast(&end_token_id))); } return data; } @@ -79,12 +83,17 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, const std::vector& outputs) { if (sample_idx >= samples_.size() || outputs.empty()) return {0}; + li; const auto& output_tokens = *(reinterpret_cast*>(outputs[0])); + li; sample_output_tokens_[sample_idx] = output_tokens; used_sample_ids_.insert(sample_idx); + li; + LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) << "/100"; + return {1}; } @@ -98,12 +107,14 @@ bool MmluGen::ComputeSampleAccuracy(const int sample_idx) { std::string prediction; sp_processor->Decode(sample_output_tokens_[sample_idx], &prediction).ok(); + li; LOG(INFO) << "index: " << std::to_string(sample_idx) << std::endl; LOG(INFO) << "Output: [[[" << prediction << "]]]" << std::endl; char predicted_char = find_answer_char(prediction); const std::string& correct = samples_[sample_idx]->answer; + li; return (predicted_char == correct[0]); } diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index 180d41a10..ce8189418 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -66,11 +66,6 @@ class MmluGen : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; - - std::string start_token = ""; - std::string end_token = ""; - int start_token_id; - int end_token_id; }; } // namespace mobile diff --git a/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h index f4a2a76fe..5a7d36ecb 100644 --- a/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h +++ b/flutter/cpp/datasets/mmlu_utils/sentencepiece_utils.h @@ -17,13 +17,14 @@ limitations under the License. #define MLPERF_DATASETS_MMLU_UTILS_SENTENCEPIECE_UTILS_H_ #include +#include #include "src/sentencepiece_processor.h" namespace mlperf { namespace mobile { -static sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( +static sentencepiece::SentencePieceProcessor* LoadSentencePieceProcessor( std::string path) { std::ifstream input(path, std::ios::binary); std::string serialized_proto = std::string( @@ -33,6 +34,46 @@ static sentencepiece::SentencePieceProcessor *LoadSentencePieceProcessor( return processor; } +inline static std::string FormatLlamaUserPrompt( + std::string_view user_content, std::string_view system_content = "", + bool add_generation_prompt = true, bool add_bos = true) { + static constexpr const char* kBOS = "<|begin_of_text|>"; + static constexpr const char* kStartHeader = "<|start_header_id|>"; + static constexpr const char* kEndHeader = "<|end_header_id|>"; + static constexpr const char* kEOT = "<|eot_id|>"; + + std::string out; + out.reserve(user_content.size() + 64); + + if (add_bos) out += kBOS; + + if (!system_content.empty()) { + out += kStartHeader; + out += "system"; + out += kEndHeader; + out += "\n"; + out.append(system_content); + out += "\n"; + out += kEOT; + } + + out += kStartHeader; + out += "user"; + out += kEndHeader; + out += "\n"; + out.append(user_content); + out += "\n"; + out += kEOT; + + if (add_generation_prompt) { + out += kStartHeader; + out += "assistant"; + out += kEndHeader; + out += "\n"; + } + return out; +} + } // namespace mobile } // namespace mlperf #endif // MLPERF_DATASETS_MMLU_UTILS_SENTENCEPIECE_UTILS_H_ diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index ef7b0bac4..fc4dd10b1 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -75,6 +75,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( ::std::unique_ptr<::mlperf::mobile::Dataset> dataset; std::string sp_path; + std::string sp_path_filename; switch (in->dataset_type) { case ::mlperf::mobile::DatasetConfig::IMAGENET: dataset = std::make_unique<::mlperf::mobile::Imagenet>( @@ -109,8 +110,11 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( break; case ::mlperf::mobile::DatasetConfig::MMLU: for (auto setting : settings.benchmark_setting().custom_setting()) { - if (setting.id() == "llm_tokenizer_path") sp_path = setting.value(); + if (setting.id() == "tokenizer_filename") sp_path_filename = setting.value(); } + sp_path = in->backend_model_path; + sp_path += '/' + sp_path_filename; + LOG(INFO) << "SP path: " << sp_path; dataset = std::make_unique<::mlperf::mobile::MmluGen>( backend.get(), in->dataset_data_path, sp_path, true /*zero-shot*/); break; diff --git a/flutter/lib/data/results/dataset_info.dart b/flutter/lib/data/results/dataset_info.dart index 03cda37cb..458c0b84d 100644 --- a/flutter/lib/data/results/dataset_info.dart +++ b/flutter/lib/data/results/dataset_info.dart @@ -15,6 +15,10 @@ enum DatasetTypeEnum { snusr, @JsonValue('COCOGEN') cocogen, + @JsonValue('MMLU') + mmlu, + @JsonValue('IFEVAL') + ifeval, } extension DatasetTypeExtension on DatasetTypeEnum { diff --git a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt index 5872cfd02..0c119aa6c 100644 --- a/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt +++ b/mobile_back_tflite/cpp/backend_tflite/backend_settings/tflite_settings_android.pbtxt @@ -279,27 +279,12 @@ benchmark_setting { id: "pipeline" value: "LLMPipeline" } -} - -benchmark_setting { - benchmark_id: "llm_instruction" - framework: "TFLite" - delegate_choice: { - delegate_name: "CPU" - accelerator_name: "cpu" - accelerator_desc: "CPU" - model_file: { - model_path: "local:///mlperf_models/llama_q8_ekv3072.tflite" - model_checksum: "54efe0be372b55303673245067beef62" - } - model_file: { - model_path: "local:///mlperf_models/llama3_1b.spm.model" - model_checksum: "2ad260fc18b965ce16006d76c9327082" - } + custom_setting { + id: "model_filename" + value: "llama_q8_ekv3072.tflite" } - delegate_selected: "CPU" custom_setting { - id: "pipeline" - value: "LLMPipeline" + id: "tokenizer_filename" + value: "llama3_1b.spm.model" } } diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 5b2d12369..6d362f169 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -55,9 +55,14 @@ mlperf_backend_ptr_t LLMPipeline::backend_create( LLMBackendData *backend_data = new LLMBackendData(); + std::string model_filename = mlperf::mobile::GetConfigValue( + configs, "model_filename", std::string("")); + + std::string llm_model_path = std::string(model_path) + '/' + model_filename; + // Load the model. backend_data->model = - tflite::FlatBufferModel::BuildFromFile(model_path).release(); + tflite::FlatBufferModel::BuildFromFile(llm_model_path.c_str()).release(); if (!backend_data->model) { LOG(ERROR) << "Failed to load model: " << model_path; backend_delete(backend_data); @@ -125,11 +130,16 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query( backend_data->tensors.prefill_input_pos()->bytes); // If the prefill can fit the entire input, leave one token for decode, // otherwise prefill as much of the input as possible. - for (int i = 0; i < prefill_amount; ++i) { + int i = 0; + for (; i < prefill_amount; ++i) { backend_data->tensors.prefill_input()->data.i32[i] = backend_data->prompt_tokens[i]; backend_data->tensors.prefill_input_pos()->data.i32[i] = i; } + for (; i < max_seq_size; ++i) { + backend_data->tensors.prefill_input()->data.i32[i] = 128009; + backend_data->tensors.prefill_input_pos()->data.i32[i] = i; + } MINIMAL_CHECK(backend_data->prefill_runner->Invoke() == kTfLiteOk); @@ -169,7 +179,8 @@ mlperf_status_t LLMPipeline::backend_issue_query( backend_data->output_tokens.reserve(decode_steps); int next_token = GreedySampler(backend_data->tensors.logits_output()); - if (next_token == backend_data->stop_token_id) return MLPERF_SUCCESS; + for (int stop_token_id : backend_data->stop_token_ids) + if (next_token == stop_token_id) return MLPERF_SUCCESS; backend_data->output_tokens.push_back(next_token); int next_position = input_size; for (int i = 0; i < decode_steps; ++i) { @@ -177,7 +188,8 @@ mlperf_status_t LLMPipeline::backend_issue_query( backend_data->tensors.decode_input_pos()->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); next_token = GreedySampler(backend_data->tensors.logits_output()); - if (next_token == backend_data->stop_token_id) break; + for (int stop_token_id : backend_data->stop_token_ids) + if (next_token == stop_token_id) break; backend_data->output_tokens.push_back(next_token); next_position += 1; } @@ -192,10 +204,10 @@ mlperf_status_t LLMPipeline::backend_flush_queries( } // Return the number of inputs of the model. -// Only 2 inputs need to be provided, the tokens themselves, and the EOS token. +// Only 1 input need to be provided, the tokens themselves. // The other inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { - return 2; + return 1; } // Return the type of the ith input. @@ -209,14 +221,6 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, int32_t batch_index, int32_t i, void *data) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; - - if (i == 1) { - backend_data->stop_token_id = *(reinterpret_cast(data)); - LOG(INFO) << "stop token id: " - << std::to_string(backend_data->stop_token_id) << std::endl; - return MLPERF_SUCCESS; - } - // Reset the tokens and kv caches from potential previous runs. backend_data->output_tokens.clear(); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index cb5696799..b658e0720 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "flutter/cpp/c/type.h" #include "pipeline.h" @@ -133,9 +134,9 @@ struct LLMBackendData { kv_cache_t kv_cache; std::vector prompt_tokens; std::vector output_tokens; - uint8_t threads = 30; - int max_output_tokens = 2048; - int stop_token_id = -1; + uint8_t threads = 2; + int max_output_tokens = 4; + std::unordered_set stop_token_ids{128001, 128008, 128009}; LLMBackendData() {} From 94f3cd539d46514670f93c83b28c39b5df64f85f Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 20 Oct 2025 03:34:29 +0300 Subject: [PATCH 35/74] properly format 0-shot prompts + allow for file/directory for model path --- flutter/cpp/datasets/mmlu_gen.cc | 17 +++++++++--- .../cpp/backend_tflite/llm_pipeline.cc | 27 +++++++++++++------ 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 47c2d1c73..302b14eb1 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -33,13 +33,22 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, std::string answer = tensorflow::GetFeatureValues("answer", example).Get(0); - if (zero_shot) - input = input.substr( + if (zero_shot) { + // input-formatted shots are separated by 2 new lines, so we find the last one which is the actual question + std::string question_formatted = input.substr( input.rfind("\n\n") + - 2); // input-formatted shots are separated by 2 new lines + 2); + // input-formatted starts with a preface followed by 2 new lines, we want that too. + std::string preface = input.substr(0, input.find("\n\n") + 2); + input = preface + "Question: " + question_formatted; + + LOG(INFO) << input; + } + + + //std::string input_formatted = FormatLlamaUserPrompt(input, "Provide only the answer letter, do not provide any explanation or preface."); - std::string input_formatted = FormatLlamaUserPrompt(input, "Provide only the answer letter, do not provide any explanation or preface."); std::vector input_tokens; sp_processor->Encode(input.c_str(), &input_tokens).ok(); diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 6d362f169..6437d24c4 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -55,10 +55,11 @@ mlperf_backend_ptr_t LLMPipeline::backend_create( LLMBackendData *backend_data = new LLMBackendData(); - std::string model_filename = mlperf::mobile::GetConfigValue( - configs, "model_filename", std::string("")); - - std::string llm_model_path = std::string(model_path) + '/' + model_filename; + std::string llm_model_path = std::string(model_path); + // Checking if the last section of the path doesn't have a file extension (indicates a directory is provided). + // Could be problematic when using hidden directories, in which case it would be best to provide a trailing slash. + if (llm_model_path.substr(llm_model_path.rfind('/')+1).find('.') == std::string::npos) + llm_model_path += '/' + mlperf::mobile::GetConfigValue(configs, "model_filename", std::string("")); // Load the model. backend_data->model = @@ -163,6 +164,18 @@ mlperf_status_t LLMPipeline::backend_issue_query( mlperf_backend_ptr_t backend_ptr, ft_callback callback, void *context) { LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + auto check_stop_id = [backend_data] (int id) { + for (int stop_token_id : backend_data->stop_token_ids) { + LOG(INFO) << std::to_string(id) << " -:- " << std::to_string(stop_token_id); + if (id == stop_token_id) { + LOG(INFO) << "BROKEN!"; + return true; + } + } + return false; + }; + + backend_issue_first_token_query(backend_ptr); callback(context); @@ -179,8 +192,7 @@ mlperf_status_t LLMPipeline::backend_issue_query( backend_data->output_tokens.reserve(decode_steps); int next_token = GreedySampler(backend_data->tensors.logits_output()); - for (int stop_token_id : backend_data->stop_token_ids) - if (next_token == stop_token_id) return MLPERF_SUCCESS; + if (check_stop_id(next_token)) return MLPERF_SUCCESS; backend_data->output_tokens.push_back(next_token); int next_position = input_size; for (int i = 0; i < decode_steps; ++i) { @@ -188,10 +200,9 @@ mlperf_status_t LLMPipeline::backend_issue_query( backend_data->tensors.decode_input_pos()->data.i32[0] = next_position; MINIMAL_CHECK(backend_data->decode_runner->Invoke() == kTfLiteOk); next_token = GreedySampler(backend_data->tensors.logits_output()); - for (int stop_token_id : backend_data->stop_token_ids) - if (next_token == stop_token_id) break; backend_data->output_tokens.push_back(next_token); next_position += 1; + if (check_stop_id(next_token)) break; } return MLPERF_SUCCESS; From e56d622e8bd4627482a76b2bdd126af596ffb430 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 20 Oct 2025 03:39:17 +0300 Subject: [PATCH 36/74] formatting --- flutter/cpp/datasets/ifeval.cc | 1 - flutter/cpp/datasets/mmlu_gen.cc | 18 +-- flutter/cpp/flutter/dart_run_benchmark.cc | 3 +- .../cpp/backend_tflite/llm_pipeline.cc | 106 +++++++++--------- .../cpp/backend_tflite/llm_pipeline.h | 2 +- 5 files changed, 64 insertions(+), 66 deletions(-) diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 34124de2a..05b3eaab5 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -30,7 +30,6 @@ IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, tensorflow::GetFeatureValues("prompt", example).Get(0); auto instructions = BuildInstructions(example); - std::string input_formatted = FormatLlamaUserPrompt(prompt); std::vector input_tokens; sp_processor->Encode(input_formatted.c_str(), &input_tokens).ok(); diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 302b14eb1..8f999ba30 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -7,7 +7,6 @@ #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature_util.h" - #define li LOG(INFO) << "li:" << __FILE__ << ":" << __LINE__ << "@" << __func__ namespace mlperf { @@ -34,11 +33,11 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, tensorflow::GetFeatureValues("answer", example).Get(0); if (zero_shot) { - // input-formatted shots are separated by 2 new lines, so we find the last one which is the actual question - std::string question_formatted = input.substr( - input.rfind("\n\n") + - 2); - // input-formatted starts with a preface followed by 2 new lines, we want that too. + // input-formatted shots are separated by 2 new lines, so we find the last + // one which is the actual question + std::string question_formatted = input.substr(input.rfind("\n\n") + 2); + // input-formatted starts with a preface followed by 2 new lines, we want + // that too. std::string preface = input.substr(0, input.find("\n\n") + 2); input = preface + "Question: " + question_formatted; @@ -46,8 +45,8 @@ MmluGen::MmluGen(Backend* backend, const std::string& input_tfrecord, LOG(INFO) << input; } - - //std::string input_formatted = FormatLlamaUserPrompt(input, "Provide only the answer letter, do not provide any explanation or preface."); + // std::string input_formatted = FormatLlamaUserPrompt(input, "Provide only + // the answer letter, do not provide any explanation or preface."); std::vector input_tokens; sp_processor->Encode(input.c_str(), &input_tokens).ok(); @@ -101,7 +100,8 @@ std::vector MmluGen::ProcessOutput(const int sample_idx, used_sample_ids_.insert(sample_idx); li; - LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) << "/100"; + LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) + << "/100"; return {1}; } diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index fc4dd10b1..96df91ee3 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -110,7 +110,8 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( break; case ::mlperf::mobile::DatasetConfig::MMLU: for (auto setting : settings.benchmark_setting().custom_setting()) { - if (setting.id() == "tokenizer_filename") sp_path_filename = setting.value(); + if (setting.id() == "tokenizer_filename") + sp_path_filename = setting.value(); } sp_path = in->backend_model_path; sp_path += '/' + sp_path_filename; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index 6437d24c4..ed6ea5d8e 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -38,28 +38,31 @@ static bool backendExists = false; // Destroy the backend pointer and its data. void LLMPipeline::backend_delete(mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; if (backend_data) delete backend_data; backendExists = false; } // Create a new backend and return the pointer to it. mlperf_backend_ptr_t LLMPipeline::backend_create( - const char *model_path, mlperf_backend_configuration_t *configs, - const char *native_lib_path) { + const char* model_path, mlperf_backend_configuration_t* configs, + const char* native_lib_path) { // Verify only one instance of the backend exists at any time if (backendExists) { LOG(ERROR) << "Only one backend instance should exist at a time"; return nullptr; } - LLMBackendData *backend_data = new LLMBackendData(); + LLMBackendData* backend_data = new LLMBackendData(); std::string llm_model_path = std::string(model_path); - // Checking if the last section of the path doesn't have a file extension (indicates a directory is provided). - // Could be problematic when using hidden directories, in which case it would be best to provide a trailing slash. - if (llm_model_path.substr(llm_model_path.rfind('/')+1).find('.') == std::string::npos) - llm_model_path += '/' + mlperf::mobile::GetConfigValue(configs, "model_filename", std::string("")); + // Checking if the last section of the path doesn't have a file extension + // (indicates a directory is provided). Could be problematic when using hidden + // directories, in which case it would be best to provide a trailing slash. + if (llm_model_path.substr(llm_model_path.rfind('/') + 1).find('.') == + std::string::npos) + llm_model_path += '/' + mlperf::mobile::GetConfigValue( + configs, "model_filename", std::string("")); // Load the model. backend_data->model = @@ -88,20 +91,20 @@ mlperf_backend_ptr_t LLMPipeline::backend_create( } // Vendor name who create this backend. -const char *LLMPipeline::backend_vendor_name(mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; +const char* LLMPipeline::backend_vendor_name(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; return backend_data->vendor; } -const char *LLMPipeline::backend_accelerator_name( +const char* LLMPipeline::backend_accelerator_name( mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; return backend_data->accelerator; } // Return the name of this backend. -const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; +const char* LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; return backend_data->name; } @@ -109,7 +112,7 @@ const char *LLMPipeline::backend_name(mlperf_backend_ptr_t backend_ptr) { // inference. This function exclusively handles the input tokens. mlperf_status_t LLMPipeline::backend_issue_first_token_query( mlperf_backend_ptr_t backend_ptr) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; int max_seq_size = backend_data->tensors.prefill_input()->dims->data[1]; int kv_cache_max_size = backend_data->tensors.kv_cache_k_0()->dims->data[1]; @@ -161,21 +164,16 @@ mlperf_status_t LLMPipeline::backend_issue_first_token_query( // Run the output token producing decode inference. // This function exclusively takes output tokens to produce more output tokens. mlperf_status_t LLMPipeline::backend_issue_query( - mlperf_backend_ptr_t backend_ptr, ft_callback callback, void *context) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + mlperf_backend_ptr_t backend_ptr, ft_callback callback, void* context) { + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; - auto check_stop_id = [backend_data] (int id) { + auto check_stop_id = [backend_data](int id) { for (int stop_token_id : backend_data->stop_token_ids) { - LOG(INFO) << std::to_string(id) << " -:- " << std::to_string(stop_token_id); - if (id == stop_token_id) { - LOG(INFO) << "BROKEN!"; - return true; - } + if (id == stop_token_id) return true; } return false; }; - backend_issue_first_token_query(backend_ptr); callback(context); @@ -230,16 +228,16 @@ mlperf_data_t LLMPipeline::backend_get_input_type( // Set the data for ith input. mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, int32_t batch_index, int32_t i, - void *data) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + void* data) { + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; // Reset the tokens and kv caches from potential previous runs. backend_data->output_tokens.clear(); - for (auto &[_, vec] : backend_data->kv_cache) { + for (auto& [_, vec] : backend_data->kv_cache) { std::fill(vec.begin(), vec.end(), 0.0f); } - backend_data->prompt_tokens = *(reinterpret_cast *>(data)); + backend_data->prompt_tokens = *(reinterpret_cast*>(data)); uint16_t effective_prefill_token_size = backend_data->prompt_tokens.size() - 1; // assuming max tokens is <16k @@ -281,29 +279,29 @@ mlperf_data_t LLMPipeline::backend_get_output_type( // Get the data from ith output. mlperf_status_t LLMPipeline::backend_get_output( mlperf_backend_ptr_t backend_ptr, uint32_t batch_index, int32_t i, - void **data) { - LLMBackendData *backend_data = (LLMBackendData *)backend_ptr; + void** data) { + LLMBackendData* backend_data = (LLMBackendData*)backend_ptr; if (i != 0) return MLPERF_FAILURE; - *data = reinterpret_cast(&backend_data->output_tokens); + *data = reinterpret_cast(&backend_data->output_tokens); return MLPERF_SUCCESS; } void LLMPipeline::backend_convert_inputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, - uint8_t *data) {} + uint8_t* data) {} void LLMPipeline::backend_convert_outputs(mlperf_backend_ptr_t backend_ptr, int bytes, int width, int height, - uint8_t *data) {} + uint8_t* data) {} -void *LLMPipeline::backend_get_buffer(size_t n) { return ::operator new(n); } +void* LLMPipeline::backend_get_buffer(size_t n) { return ::operator new(n); } -void LLMPipeline::backend_release_buffer(void *p) { ::operator delete(p); } +void LLMPipeline::backend_release_buffer(void* p) { ::operator delete(p); } -tflite::Interpreter *LLMPipeline::BuildInterpreter( - tflite::FlatBufferModel *model, int num_threads) { +tflite::Interpreter* LLMPipeline::BuildInterpreter( + tflite::FlatBufferModel* model, int num_threads) { tflite::ops::builtin::BuiltinOpResolver resolver; // NOTE: We need to manually register optimized OPs for KV-cache and // Scaled Dot Product Attention (SDPA). @@ -318,8 +316,8 @@ tflite::Interpreter *LLMPipeline::BuildInterpreter( return interpreter.release(); } -kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter *interpreter) { - tflite::SignatureRunner *runner = interpreter->GetSignatureRunner("decode"); +kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter* interpreter) { + tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); if (runner == nullptr) { return {}; } @@ -335,7 +333,7 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter *interpreter) { std::string k_cache_name = "kv_cache_k_" + std::to_string(i); std::string v_cache_name = "kv_cache_v_" + std::to_string(i); // We are assuming K and V tensors are of the same shape. - TfLiteTensor *tensor = runner->input_tensor(k_cache_name.c_str()); + TfLiteTensor* tensor = runner->input_tensor(k_cache_name.c_str()); size_t count = tensor->bytes / sizeof(float); kv_cache.emplace(k_cache_name, std::vector>(count, 0.0f)); @@ -346,11 +344,11 @@ kv_cache_t LLMPipeline::BuildKVCache(tflite::Interpreter *interpreter) { return kv_cache; } -void LLMPipeline::PrepareRunner(tflite::SignatureRunner *runner, - kv_cache_t &kv_cache) { - for (auto &[name, cache] : kv_cache) { +void LLMPipeline::PrepareRunner(tflite::SignatureRunner* runner, + kv_cache_t& kv_cache) { + for (auto& [name, cache] : kv_cache) { TfLiteCustomAllocation allocation = {}; - allocation.data = static_cast(cache.data()); + allocation.data = static_cast(cache.data()); allocation.bytes = cache.size() * sizeof(float); // Both input and output tensors are set to the same buffer. Not all // delegates support this in-place update. For those cases, we need to do @@ -363,18 +361,18 @@ void LLMPipeline::PrepareRunner(tflite::SignatureRunner *runner, MINIMAL_CHECK_VOID(runner->AllocateTensors() == kTfLiteOk); } -tflite::SignatureRunner *LLMPipeline::GetPrefillRunner( - tflite::Interpreter *interpreter, std::size_t num_input_tokens, - kv_cache_t &kv_cache) { +tflite::SignatureRunner* LLMPipeline::GetPrefillRunner( + tflite::Interpreter* interpreter, std::size_t num_input_tokens, + kv_cache_t& kv_cache) { // Find the prefill signature length that best matches the input token size. - tflite::SignatureRunner *runner = nullptr; + tflite::SignatureRunner* runner = nullptr; // int best_seq_size = -1; size_t delta = std::numeric_limits::max(); size_t max_prefill_size = 0; std::string max_prefill_key = std::string(""); - for (const std::string *key : interpreter->signature_keys()) { + for (const std::string* key : interpreter->signature_keys()) { if (key->find("prefill") == std::string::npos) continue; - TfLiteTensor *input_pos = interpreter->GetSignatureRunner(key->c_str()) + TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str()) ->input_tensor("input_pos"); // The expected shape for input position is [Seq]. size_t seq_size = input_pos->dims->data[0]; @@ -398,16 +396,16 @@ tflite::SignatureRunner *LLMPipeline::GetPrefillRunner( return runner; } -tflite::SignatureRunner *LLMPipeline::GetDecodeRunner( - tflite::Interpreter *interpreter, kv_cache_t &kv_cache) { - tflite::SignatureRunner *runner = interpreter->GetSignatureRunner("decode"); +tflite::SignatureRunner* LLMPipeline::GetDecodeRunner( + tflite::Interpreter* interpreter, kv_cache_t& kv_cache) { + tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode"); MINIMAL_CHECK_PTR(runner != nullptr); PrepareRunner(runner, kv_cache); return runner; } // A basic greedy sampler (equivalent to argmax). -int LLMPipeline::GreedySampler(const TfLiteTensor *logits) { +int LLMPipeline::GreedySampler(const TfLiteTensor* logits) { float max_value = -std::numeric_limits::infinity(); int max_index = 0; // logits shape: [Batch, Seq, Vocab], Dtype: float diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index b658e0720..8dd5181c0 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include #include +#include #include "flutter/cpp/c/type.h" #include "pipeline.h" From 9120d633480a640e6c2fb0bbef66e3aa3b15aeb4 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Mon, 27 Oct 2025 09:21:30 +0300 Subject: [PATCH 37/74] potential fix for windows C2440 --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index f086b5733..25bd117e0 100644 --- a/.bazelrc +++ b/.bazelrc @@ -109,7 +109,7 @@ build:windows --host_linkopt=/OPT:ICF # MSVC does not support XNNPACK AVXVNNI instructions (causes C2440 error). build:windows --define=xnn_enable_avxvnni=false -#build:windows --define=xnn_enable_avxvnniint8=false +build:windows --define=xnn_enable_avxvnniint8=false # Address sanitizer build:asan --strip=never From 002d2d05e3d03224fdd455e78cd40cbdb6a14197 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 05:50:58 +0300 Subject: [PATCH 38/74] fix for aligned free for windows --- .../cpp/backend_tflite/llm_pipeline.h | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 8dd5181c0..58bb24377 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -20,6 +20,11 @@ limitations under the License. #include #include #include +#include + +#if defined(_MSC_VER) +#include +#endif #include "flutter/cpp/c/type.h" #include "pipeline.h" @@ -63,14 +68,25 @@ class AlignedAllocator { // std::size_t padding = tflite::kDefaultTensorAlignment - // (size % tflite::kDefaultTensorAlignment); // size += padding; + +#if defined(_MSC_VER) + ptr = _aligned_malloc(size tflite::kDefaultTensorAlignment); +#else int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size); if (ret != 0) { return nullptr; } +#endif return static_cast(ptr); }; - void deallocate(T *ptr, std::size_t n) { free(ptr); } + void deallocate(T *ptr, std::size_t n) { +#if defined(_MSC_VER) + _aligned_free(ptr); +#else + free(ptr); +#endif + } }; using kv_cache_t = From 9f81bddbb4cd8358a7f44e90ab2a5cedafb087e5 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 06:22:09 +0300 Subject: [PATCH 39/74] potential fix for IOS / windows CI issues --- .bazelrc | 2 ++ mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index 25bd117e0..8b238a704 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,6 +79,8 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc +build:ios --copt=-DEIGEN_EXCEPTIONS=0 +build:ios --copt=-DEIGEN_NO_EXCEPTIONS # Windows configs diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 58bb24377..8f55f58f0 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -70,7 +70,7 @@ class AlignedAllocator { // size += padding; #if defined(_MSC_VER) - ptr = _aligned_malloc(size tflite::kDefaultTensorAlignment); + ptr = _aligned_malloc(size, tflite::kDefaultTensorAlignment); #else int ret = posix_memalign(&ptr, tflite::kDefaultTensorAlignment, size); if (ret != 0) { @@ -151,7 +151,7 @@ struct LLMBackendData { std::vector prompt_tokens; std::vector output_tokens; uint8_t threads = 2; - int max_output_tokens = 4; + int max_output_tokens = 1024; std::unordered_set stop_token_ids{128001, 128008, 128009}; LLMBackendData() {} From 93d53529507a29f190384b30f5fd7ddb563b3d09 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 06:26:42 +0300 Subject: [PATCH 40/74] ifeval check cleanup and bugfixes --- flutter/assets/tasks.pbtxt | 4 + flutter/cpp/binary/main.cc | 6 +- flutter/cpp/datasets/ifeval.cc | 153 ++++++++----------- flutter/cpp/datasets/ifeval.h | 35 ++--- flutter/cpp/datasets/ifeval_utils/common.h | 30 +++- flutter/cpp/datasets/ifeval_utils/types.h | 165 ++++++++++++++------- 6 files changed, 230 insertions(+), 163 deletions(-) diff --git a/flutter/assets/tasks.pbtxt b/flutter/assets/tasks.pbtxt index 51acd52df..06763f361 100644 --- a/flutter/assets/tasks.pbtxt +++ b/flutter/assets/tasks.pbtxt @@ -387,6 +387,10 @@ task { id: "LLM" name: "LLM" } + custom_config { + id: "llm_tokenizer_path" + value: "llama3_1b.spm.model" + } } task { diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index d8510197c..397cfcb4d 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -426,7 +426,6 @@ int Main(int argc, char *argv[]) { dataset_flags.end()); } break; case DatasetConfig::IFEVAL: { - bool loose_follow = false; LOG(INFO) << "IFEval dataset for LLM benchmark"; std::string input_tfrecord, sp_path = ""; std::vector dataset_flags{ @@ -437,15 +436,12 @@ int Main(int argc, char *argv[]) { Flag::CreateFlag("sp_path", &sp_path, "Path to the sentencepiece model file.", Flag::kRequired), - Flag::CreateFlag("loose-follow", &loose_follow, - "Whether to loosely check if the instructions are " - "being followed"), }; if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && backend) { dataset.reset( - new IFEval(backend.get(), input_tfrecord, sp_path, loose_follow)); + new IFEval(backend.get(), input_tfrecord, sp_path)); } // Adds to flag_list for showing help. flag_list.insert(flag_list.end(), dataset_flags.begin(), diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 05b3eaab5..7f6f810e9 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -11,10 +11,8 @@ namespace mlperf { namespace mobile { IFEval::IFEval(Backend* backend, const std::string& input_tfrecord, - const std::string& sp_path, bool loose_follow) - : sample_reader_(input_tfrecord), - loose_follow_(loose_follow), - Dataset(backend) { + const std::string& sp_path) + : sample_reader_(input_tfrecord), Dataset(backend) { sp_processor = std::unique_ptr( LoadSentencePieceProcessor(sp_path)); @@ -76,6 +74,10 @@ std::vector IFEval::ProcessOutput(const int sample_idx, *(reinterpret_cast*>(outputs[0])); sample_output_tokens_[sample_idx] = output_tokens; + used_sample_ids_.insert(sample_idx); + + LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) + << "/29"; return {1}; } @@ -87,49 +89,77 @@ int64_t IFEval::GetOutputTokenCount(const int sample_idx) { bool IFEval::HasAccuracy() { return true; } bool IFEval::ComputeSampleAccuracy(const int sample_idx, - ifeval::GroupAccuracy& accuracy) { + ifeval::Accuracy& accuracy) { + std::string prediction; sp_processor->Decode(sample_output_tokens_[sample_idx], &prediction).ok(); - LOG(INFO) << "output(" << std::to_string(sample_idx) << "): " << prediction - << std::endl; - + bool is_prompt_correct_loose = true; + bool is_prompt_correct_strict = true; for (const auto& instruction : samples_[sample_idx]->instructions) { - bool is_correct = instruction->IsFollowed(prediction, loose_follow_); - ProcessResult(instruction->Group(), is_correct, accuracy); + + bool is_correct_loose = instruction->IsFollowed(prediction, true); + bool is_correct_strict = instruction->IsFollowed(prediction, false); + + accuracy.instruction_total++; + accuracy.instruction_correct_loose += is_correct_loose ? 1 : 0; + accuracy.instruction_correct_strict += is_correct_strict ? 1 : 0; + + is_prompt_correct_loose = is_prompt_correct_loose ? is_correct_loose : false; + is_prompt_correct_strict = is_prompt_correct_strict ? is_correct_strict : false; } + + accuracy.prompt_total++; + accuracy.prompt_correct_loose += is_prompt_correct_loose ? 1 : 0; + accuracy.prompt_correct_strict += is_prompt_correct_strict ? 1 : 0; + + return true; } float IFEval::ComputeAccuracy() { - uint16_t correct_sum; - uint16_t total_sum; - ifeval::GroupAccuracy accuracy; + float instruction_loose_accuracy; + float instruction_strict_accuracy; + float prompt_loose_accuracy; + float prompt_strict_accuracy; + ifeval::Accuracy accuracy; for (auto sample_id : used_sample_ids_) { ComputeSampleAccuracy(sample_id, accuracy); } - correct_sum += accuracy.change_case_correct; - correct_sum += accuracy.combination_correct; - correct_sum += accuracy.detectable_content_correct; - correct_sum += accuracy.detectable_format_correct; - correct_sum += accuracy.keywords_correct; - correct_sum += accuracy.language_correct; - correct_sum += accuracy.length_constraints_correct; - correct_sum += accuracy.punctuation_correct; - correct_sum += accuracy.startend_correct; - - total_sum += accuracy.change_case_total; - total_sum += accuracy.combination_total; - total_sum += accuracy.detectable_content_total; - total_sum += accuracy.detectable_format_total; - total_sum += accuracy.keywords_total; - total_sum += accuracy.language_total; - total_sum += accuracy.length_constraints_total; - total_sum += accuracy.punctuation_total; - total_sum += accuracy.startend_total; - - return total_sum > 0 ? static_cast(correct_sum) / total_sum : 0.0f; + instruction_loose_accuracy = + accuracy.instruction_total > 0 + ? static_cast(accuracy.instruction_correct_loose) / + accuracy.instruction_total + : 0.0f; + instruction_strict_accuracy = + accuracy.instruction_total > 0 + ? static_cast(accuracy.instruction_correct_strict) / + accuracy.instruction_total + : 0.0f; + prompt_loose_accuracy = + accuracy.prompt_total > 0 + ? static_cast(accuracy.prompt_correct_loose) / + accuracy.prompt_total + : 0.0f; + prompt_strict_accuracy = + accuracy.prompt_total > 0 + ? static_cast(accuracy.prompt_correct_strict) / + accuracy.prompt_total + : 0.0f; + + LOG(INFO) << "Instruction-level loose-accuracy: " + << std::to_string(instruction_loose_accuracy); + LOG(INFO) << "Instruction-level strict-accuracy: " + << std::to_string(instruction_strict_accuracy); + LOG(INFO) << "Prompt-level loose-accuracy: " + << std::to_string(prompt_loose_accuracy); + LOG(INFO) << "Prompt-level strict-accuracy: " + << std::to_string(prompt_strict_accuracy); + + return (instruction_loose_accuracy + instruction_strict_accuracy + + prompt_loose_accuracy + prompt_strict_accuracy) / + 4.0f; } std::string IFEval::ComputeAccuracyString() { @@ -421,60 +451,5 @@ IFEval::BuildInstructions(const tensorflow::Example& ex) { return out; } -inline void IFEval::ProcessResult(ifeval::InstructionGroup group, - bool is_correct, - ifeval::GroupAccuracy& accuracy) { - uint8_t correct_value = is_correct ? 1 : 0; - switch (group) { - case ifeval::InstructionGroup::CHANGE_CASE: - accuracy.change_case_correct += correct_value; - accuracy.change_case_total++; - break; - - case ifeval::InstructionGroup::COMBINATION: - accuracy.combination_correct += correct_value; - accuracy.combination_total++; - break; - - case ifeval::InstructionGroup::DETECTABLE_CONTENT: - accuracy.detectable_content_correct += correct_value; - accuracy.detectable_content_total++; - break; - - case ifeval::InstructionGroup::DETECTABLE_FORMAT: - accuracy.detectable_format_correct += correct_value; - accuracy.detectable_format_total++; - break; - - case ifeval::InstructionGroup::KEYWORDS: - accuracy.keywords_correct += correct_value; - accuracy.keywords_total++; - break; - - case ifeval::InstructionGroup::LANGUAGE: - accuracy.language_correct += correct_value; - accuracy.language_total++; - break; - - case ifeval::InstructionGroup::LENGTH_CONSTRAINTS: - accuracy.length_constraints_correct += correct_value; - accuracy.length_constraints_total++; - break; - - case ifeval::InstructionGroup::PUNCTUATION: - accuracy.punctuation_correct += correct_value; - accuracy.punctuation_total++; - break; - - case ifeval::InstructionGroup::STARTEND: - accuracy.startend_correct += correct_value; - accuracy.startend_total++; - break; - - default: - break; - } -} - } // namespace mobile } // namespace mlperf diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index 9d07fef27..a8eaf63ce 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -19,22 +19,28 @@ namespace mlperf { namespace mobile { namespace ifeval { -struct GroupAccuracy { - size_t change_case_correct = 0, combination_correct = 0, - detectable_content_correct = 0, detectable_format_correct = 0, - keywords_correct = 0, language_correct = 0, - length_constraints_correct = 0, punctuation_correct = 0, - startend_correct = 0; - size_t change_case_total = 0, combination_total = 0, - detectable_content_total = 0, detectable_format_total = 0, - keywords_total = 0, language_total = 0, length_constraints_total = 0, - punctuation_total = 0, startend_total = 0; +// struct GroupAccuracy { +// size_t change_case_correct = 0, combination_correct = 0, +// detectable_content_correct = 0, detectable_format_correct = 0, +// keywords_correct = 0, language_correct = 0, +// length_constraints_correct = 0, punctuation_correct = 0, +// startend_correct = 0; +// size_t change_case_total = 0, combination_total = 0, +// detectable_content_total = 0, detectable_format_total = 0, +// keywords_total = 0, language_total = 0, length_constraints_total = +// 0, punctuation_total = 0, startend_total = 0; +// }; + +struct Accuracy { + size_t prompt_correct_loose = 0, prompt_correct_strict = 0, prompt_total = 0, + instruction_correct_loose = 0, instruction_correct_strict = 0, + instruction_total = 0; }; } // namespace ifeval class IFEval : public Dataset { public: IFEval(Backend* backend, const std::string& input_tfrecord, - const std::string& sp_path, bool loose_follow); + const std::string& sp_path); const std::string& Name() override { return name_; } @@ -57,7 +63,7 @@ class IFEval : public Dataset { bool HasAccuracy() override; bool ComputeSampleAccuracy(const int sample_idx, - ifeval::GroupAccuracy& accuracy); + ifeval::Accuracy& accuracy); float ComputeAccuracy() override; @@ -66,9 +72,6 @@ class IFEval : public Dataset { inline std::vector> BuildInstructions( const tensorflow::Example& ex); - inline void ProcessResult(ifeval::InstructionGroup group, bool is_correct, - ifeval::GroupAccuracy& accuracy); - private: const std::string name_ = "IFEval"; @@ -79,8 +82,6 @@ class IFEval : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; - - bool loose_follow_; }; } // namespace mobile diff --git a/flutter/cpp/datasets/ifeval_utils/common.h b/flutter/cpp/datasets/ifeval_utils/common.h index f1bcdfe03..0345d0fbf 100644 --- a/flutter/cpp/datasets/ifeval_utils/common.h +++ b/flutter/cpp/datasets/ifeval_utils/common.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -48,8 +49,33 @@ inline bool contains_string(const std::string& text, } inline bool contains_word(const std::string& text, const std::string& word) { - std::regex rx("\\b" + word + "\\b", std::regex::icase); - return std::regex_search(text.begin(), text.end(), rx); + if (word.empty()) return false; + + LOG(INFO) << "searching for '" << word << "'..."; + + auto to_lower_ascii = [](std::string s) { + for (char& c : s) c = std::tolower(static_cast(c)); + return s; + }; + auto is_word_char = [](unsigned char c) { + return std::isalnum(c) || c == '_'; // match std::regex \b notion of "word" + }; + + std::string t = to_lower_ascii(text); + std::string w = to_lower_ascii(word); + + // Scan all occurrences of w in t and check word boundaries + std::size_t pos = 0; + while ((pos = t.find(w, pos)) != std::string::npos) { + const bool left_ok = + (pos == 0) || !is_word_char(static_cast(t[pos - 1])); + const std::size_t end = pos + w.size(); + const bool right_ok = + (end == t.size()) || !is_word_char(static_cast(t[end])); + if (left_ok && right_ok) return true; + ++pos; // continue searching (overlapping-safe) + } + return false; } inline bool contains_none(const std::string& text, diff --git a/flutter/cpp/datasets/ifeval_utils/types.h b/flutter/cpp/datasets/ifeval_utils/types.h index cffce43ea..e90f8c8a7 100644 --- a/flutter/cpp/datasets/ifeval_utils/types.h +++ b/flutter/cpp/datasets/ifeval_utils/types.h @@ -44,7 +44,7 @@ class Instruction { auto transformations = transform_response(resp); for (std::string transformation : transformations) { - if (verify_(resp)) return true; + if (verify_(transformation)) return true; } return false; } @@ -66,23 +66,42 @@ class CapitalWordFrequency : public Instruction { int threshold_; Relation rel_; - static size_t CapitalWords(const std::string& resp) { - size_t words = 0; + static bool IsAllCapsToken(std::string_view t) { + // trim leading/trailing punctuation (keep '-' and '\'' because they appear + // inside words) + auto is_trim = [](unsigned char c) { + return !(std::isalnum(c) || c == '-' || c == '\''); + }; + size_t b = 0, e = t.size(); + while (b < e && is_trim((unsigned char)t[b])) ++b; + while (e > b && is_trim((unsigned char)t[e - 1])) --e; + if (b >= e) return false; + + bool seen_alpha = false; + for (size_t i = b; i < e; ++i) { + unsigned char c = (unsigned char)t[i]; + if (std::isalpha(c)) { + seen_alpha = true; + if (std::islower(c)) + return false; // any lowercase letter breaks ALL-CAPS + } + // digits, '-', '\'' are allowed and ignored for casing + } + return seen_alpha; // at least one letter, and no lowercase letters + } + + static size_t CountAllCapsWords(const std::string& resp) { + size_t count = 0; std::istringstream is(resp); - std::string w; - while (is >> w) { - size_t i = 0; - while (i < w.size() && !std::isalnum((unsigned char)w[i]) && - !std::isupper((unsigned char)w[i])) - ++i; - if (i >= w.size()) continue; - ++words; + std::string tok; + while (is >> tok) { + if (IsAllCapsToken(tok)) ++count; } - return words; + return count; } - bool verify_(const std::string& resp) const override { - size_t words = CapitalWords(resp); + virtual bool verify_(const std::string& resp) const override { + size_t words = CountAllCapsWords(resp); return compare(words, threshold_, rel_); } }; @@ -93,9 +112,10 @@ class EnglishCapital : public Instruction { constexpr InstructionGroup Group() override { return CHANGE_CASE; } private: - bool verify_(const std::string& resp) const override { - return std::all_of(resp.begin(), resp.end(), - [](unsigned char c) { return std::isupper(c); }); + virtual bool verify_(const std::string& resp) const override { + return std::all_of(resp.begin(), resp.end(), [](unsigned char c) { + return !std::isalpha(c) || std::isupper(c); + }); } }; @@ -105,9 +125,10 @@ class EnglishLowercase : public Instruction { constexpr InstructionGroup Group() override { return CHANGE_CASE; } private: - bool verify_(const std::string& resp) const override { - return std::all_of(resp.begin(), resp.end(), - [](unsigned char c) { return std::islower(c); }); + virtual bool verify_(const std::string& resp) const override { + return std::all_of(resp.begin(), resp.end(), [](unsigned char c) { + return !std::isalpha(c) || std::islower(c); + }); } }; @@ -121,7 +142,7 @@ class RepeatPrompt : public Instruction { private: std::string prompt_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { // TODO replace with startswith? return contains_string(resp, prompt_); } @@ -133,7 +154,7 @@ class TwoResponses : public Instruction { constexpr InstructionGroup Group() override { return COMBINATION; } private: - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::size_t count = 0; std::size_t pos = resp.find("******"); while (pos != std::string::npos) { @@ -153,7 +174,7 @@ class NumberPlaceholders : public Instruction { private: int n_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::size_t count = 0, pos = 0; while (pos < resp.length() && (int)count < n_) { // no need to keep looking if the requirement is @@ -188,7 +209,7 @@ class Postscript : public Instruction { private: std::string marker_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return contains_string(resp, marker_); } }; @@ -201,7 +222,7 @@ class ConstrainedResponse : public Instruction { constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } private: - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return resp == "My answer is yes." || resp == "My answer is no." || resp == "My answer is maybe."; } @@ -214,7 +235,7 @@ class JsonFormat : public Instruction { private: // TODO possibly use a C++ json validator instead - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::string t = resp; if (t.empty()) return false; if (!((t.front() == '{' && t.back() == '}') || @@ -281,7 +302,7 @@ class MultipleSections : public Instruction { } return parts; } - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { auto parts = SplitByDelim(resp, sep_); return CountNonEmpty(parts) == n_; } @@ -303,7 +324,7 @@ class NumberBulletLists : public Instruction { return out; } - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { size_t count = 0; for (const auto& line : SplitLines(resp)) { std::string t = trim(line); @@ -323,7 +344,7 @@ class NumberHighlightedSections : public Instruction { private: int n_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::size_t count = 0; std::size_t pos = 0; @@ -363,7 +384,7 @@ class Title : public Instruction { constexpr InstructionGroup Group() override { return DETECTABLE_FORMAT; } private: - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::size_t pos_open = resp.find("<<"); // TODO should an empty title be allowed? return (pos_open != std::string::npos) && @@ -382,7 +403,7 @@ class Existence : public Instruction { private: std::vector kws_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { for (const auto& k : kws_) if (!contains_word(resp, k)) return false; return true; @@ -397,7 +418,7 @@ class ForbiddenWords : public Instruction { private: std::vector bad_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return contains_none(resp, bad_); } }; @@ -412,12 +433,56 @@ class Frequency : public Instruction { int n_; std::string kw_; Relation rel_; - bool verify_(const std::string& resp) const override { - std::regex rx("\\b" + kw_ + "\\b", std::regex::icase); - size_t count = 0; - auto it = std::sregex_iterator(resp.begin(), resp.end(), rx); - auto end = std::sregex_iterator(); - for (; it != end; ++it) ++count; + + static inline std::string RegexEscape(const std::string& s) { + auto is_meta = [](unsigned char ch) { + switch (ch) { + case '^': case '$': case '.': case '|': case '?': + case '*': case '+': case '(': case ')': + case '[': case ']': case '{': case '}': case '\\': + return true; + default: + return false; + } + }; + + std::string out; + out.reserve(s.size() * 2); + for (unsigned char c : s) { + if (is_meta(c)) out.push_back('\\'); + out.push_back(static_cast(c)); + } + return out; + } + + // Build a regex that matches the keyword with custom token boundaries. + // Left boundary is (^|[^A-Za-z0-9_]) to avoid lookbehind. + // Right boundary uses a lookahead (?=$|[^A-Za-z0-9_]). + static inline std::regex MakeKeywordRegex(const std::string& keyword) { + const std::string kw = RegexEscape(keyword); + const std::string pat = + "(^|[^A-Za-z0-9_])" // left boundary (consumes 1 char or start) + "(?:" + + kw + + ")" // keyword literal + "(?=$|[^A-Za-z0-9_])"; // right boundary (zero-width lookahead) + return std::regex(pat, std::regex::icase); + } + + static inline std::size_t CountKeywordOccurrences( + const std::string& text, const std::string& keyword) { + const std::regex rx = MakeKeywordRegex(keyword); + std::size_t count = 0; + for (auto it = std::sregex_iterator(text.begin(), text.end(), rx), + end = std::sregex_iterator(); + it != end; ++it) { + ++count; + } + return count; + } + + virtual bool verify_(const std::string& resp) const override { + const std::size_t count = CountKeywordOccurrences(resp, kw_); return compare(count, (size_t)n_, rel_); } }; @@ -439,7 +504,7 @@ class LetterFrequency : public Instruction { if (std::tolower(ch) == lower) ++c; return c; } - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { size_t c = CountLetterICase(resp, letter_); return compare(c, (size_t)n_, rel_); } @@ -514,7 +579,7 @@ class ResponseLanguage : public Instruction { return non_ascii_ratio() > 0.05; } - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return LanguageHeuristic(resp, lang_); } }; @@ -560,7 +625,7 @@ class NthParagraphFirstWord : public Instruction { return paras; } - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { auto paras = SplitParagraphs(resp); if ((int)paras.size() != total_) return false; if (nth_ <= 0 || nth_ > (int)paras.size()) return false; @@ -577,13 +642,13 @@ class NumberParagraphs : public Instruction { private: unsigned n_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { std::size_t count = 0, pos = 0; - while ((pos = resp.find("***", pos)) != std::string::npos) { + while ((pos = resp.find("***\n", pos)) != std::string::npos) { ++count; - pos += 3; // advance by 3 for non-overlapping matches + pos += 4; // advance by 3 for non-overlapping matches } - return count + 1 == n_; // since *** is a saparator, the actual count is 1 + return count == n_ - 1; // since *** is a saparator, the actual count is 1 // more than the number of separators } }; @@ -597,7 +662,7 @@ class NumberSentences : public Instruction { private: int n_; Relation rel_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { size_t count = 0; for (unsigned char c : resp) { if (c == '.' || c == '!' || c == '?') ++count; @@ -615,7 +680,7 @@ class NumberWords : public Instruction { private: int n_; Relation rel_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { size_t count = 0; bool in_word = false; for (unsigned char c : resp) { @@ -639,7 +704,7 @@ class NoComma : public Instruction { constexpr InstructionGroup Group() override { return PUNCTUATION; } private: - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return resp.find(',') == std::string::npos; } }; @@ -653,7 +718,7 @@ class EndChecker : public Instruction { private: std::string end_; - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { return ends_with(resp, end_); } }; @@ -664,7 +729,7 @@ class Quotation : public Instruction { constexpr InstructionGroup Group() override { return STARTEND; } private: - bool verify_(const std::string& resp) const override { + virtual bool verify_(const std::string& resp) const override { if (resp.size() < 2) return false; return resp.front() == '"' && resp.back() == '"'; } From fc0f2415529ba5269238ab43f80fef0d0943cbfa Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 06:51:31 +0300 Subject: [PATCH 41/74] formatting --- flutter/cpp/binary/main.cc | 3 +-- flutter/cpp/datasets/ifeval.cc | 13 +++++++------ flutter/cpp/datasets/ifeval.h | 3 +-- flutter/cpp/datasets/ifeval_utils/types.h | 17 ++++++++++++++--- .../cpp/backend_tflite/llm_pipeline.h | 3 ++- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/flutter/cpp/binary/main.cc b/flutter/cpp/binary/main.cc index 397cfcb4d..e962d1b9d 100644 --- a/flutter/cpp/binary/main.cc +++ b/flutter/cpp/binary/main.cc @@ -440,8 +440,7 @@ int Main(int argc, char *argv[]) { if (Flags::Parse(&argc, const_cast(argv), dataset_flags) && backend) { - dataset.reset( - new IFEval(backend.get(), input_tfrecord, sp_path)); + dataset.reset(new IFEval(backend.get(), input_tfrecord, sp_path)); } // Adds to flag_list for showing help. flag_list.insert(flag_list.end(), dataset_flags.begin(), diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 7f6f810e9..9328c2272 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -76,8 +76,7 @@ std::vector IFEval::ProcessOutput(const int sample_idx, sample_output_tokens_[sample_idx] = output_tokens; used_sample_ids_.insert(sample_idx); - LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) - << "/29"; + LOG(INFO) << "Processed " << std::to_string(used_sample_ids_.size()) << "/29"; return {1}; } @@ -90,14 +89,12 @@ bool IFEval::HasAccuracy() { return true; } bool IFEval::ComputeSampleAccuracy(const int sample_idx, ifeval::Accuracy& accuracy) { - std::string prediction; sp_processor->Decode(sample_output_tokens_[sample_idx], &prediction).ok(); bool is_prompt_correct_loose = true; bool is_prompt_correct_strict = true; for (const auto& instruction : samples_[sample_idx]->instructions) { - bool is_correct_loose = instruction->IsFollowed(prediction, true); bool is_correct_strict = instruction->IsFollowed(prediction, false); @@ -105,8 +102,10 @@ bool IFEval::ComputeSampleAccuracy(const int sample_idx, accuracy.instruction_correct_loose += is_correct_loose ? 1 : 0; accuracy.instruction_correct_strict += is_correct_strict ? 1 : 0; - is_prompt_correct_loose = is_prompt_correct_loose ? is_correct_loose : false; - is_prompt_correct_strict = is_prompt_correct_strict ? is_correct_strict : false; + is_prompt_correct_loose = + is_prompt_correct_loose ? is_correct_loose : false; + is_prompt_correct_strict = + is_prompt_correct_strict ? is_correct_strict : false; } accuracy.prompt_total++; @@ -182,7 +181,9 @@ IFEval::BuildInstructions(const tensorflow::Example& ex) { auto get_strs = [&](const std::string& key, std::vector* vals) -> bool { const auto& sfield = tensorflow::GetFeatureValues(key, ex); + std::vector svals(sfield.begin(), sfield.end()); + *vals = std::move(svals); return true; }; diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index a8eaf63ce..2dd1e4747 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -62,8 +62,7 @@ class IFEval : public Dataset { bool HasAccuracy() override; - bool ComputeSampleAccuracy(const int sample_idx, - ifeval::Accuracy& accuracy); + bool ComputeSampleAccuracy(const int sample_idx, ifeval::Accuracy& accuracy); float ComputeAccuracy() override; diff --git a/flutter/cpp/datasets/ifeval_utils/types.h b/flutter/cpp/datasets/ifeval_utils/types.h index e90f8c8a7..b17e7f8ef 100644 --- a/flutter/cpp/datasets/ifeval_utils/types.h +++ b/flutter/cpp/datasets/ifeval_utils/types.h @@ -437,9 +437,20 @@ class Frequency : public Instruction { static inline std::string RegexEscape(const std::string& s) { auto is_meta = [](unsigned char ch) { switch (ch) { - case '^': case '$': case '.': case '|': case '?': - case '*': case '+': case '(': case ')': - case '[': case ']': case '{': case '}': case '\\': + case '^': + case '$': + case '.': + case '|': + case '?': + case '*': + case '+': + case '(': + case ')': + case '[': + case ']': + case '{': + case '}': + case '\\': return true; default: return false; diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h index 8f55f58f0..f7552cd4d 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.h @@ -16,11 +16,12 @@ limitations under the License. #ifndef TFLITE_LLM_PIPELINE_H_ #define TFLITE_LLM_PIPELINE_H_ +#include + #include #include #include #include -#include #if defined(_MSC_VER) #include From 15880a9b5469529049a8771a4a6f3a7febb1bf88 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 07:07:09 +0300 Subject: [PATCH 42/74] all possible configs for removing eigen exceptions --- .bazelrc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.bazelrc b/.bazelrc index 8b238a704..390ab22cd 100644 --- a/.bazelrc +++ b/.bazelrc @@ -81,6 +81,14 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc build:ios --copt=-DEIGEN_EXCEPTIONS=0 build:ios --copt=-DEIGEN_NO_EXCEPTIONS +build:ios --cxxopt=-DEIGEN_EXCEPTIONS=0 +build:ios --cxxopt=-DEIGEN_NO_EXCEPTIONS +build:ios --cxxopt=-fno-exceptions +build:ios --objcxxopt=-DEIGEN_EXCEPTIONS=0 +build:ios --objcxxopt=-DEIGEN_NO_EXCEPTIONS +build:ios --objcxxopt=-fno-exceptions +build:ios --objcopt=-fno-exceptions +build:ios --conlyopt=-fno-exceptions # Windows configs From 5ddbb87ba83fcefc82c702a2eb9acde595a1614d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 07:22:47 +0300 Subject: [PATCH 43/74] removed objc opts --- .bazelrc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 390ab22cd..80fe14b76 100644 --- a/.bazelrc +++ b/.bazelrc @@ -81,14 +81,10 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc build:ios --copt=-DEIGEN_EXCEPTIONS=0 build:ios --copt=-DEIGEN_NO_EXCEPTIONS +build:ios --copt=-fno-exceptions build:ios --cxxopt=-DEIGEN_EXCEPTIONS=0 build:ios --cxxopt=-DEIGEN_NO_EXCEPTIONS build:ios --cxxopt=-fno-exceptions -build:ios --objcxxopt=-DEIGEN_EXCEPTIONS=0 -build:ios --objcxxopt=-DEIGEN_NO_EXCEPTIONS -build:ios --objcxxopt=-fno-exceptions -build:ios --objcopt=-fno-exceptions -build:ios --conlyopt=-fno-exceptions # Windows configs From 7a4042a4121b584880921b114f95163df886008b Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Tue, 28 Oct 2025 23:57:03 +0300 Subject: [PATCH 44/74] use token latencies in app --- flutter/cpp/flutter/dart_run_benchmark.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flutter/cpp/flutter/dart_run_benchmark.cc b/flutter/cpp/flutter/dart_run_benchmark.cc index 96df91ee3..77ce28f25 100644 --- a/flutter/cpp/flutter/dart_run_benchmark.cc +++ b/flutter/cpp/flutter/dart_run_benchmark.cc @@ -76,6 +76,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( ::std::unique_ptr<::mlperf::mobile::Dataset> dataset; std::string sp_path; std::string sp_path_filename; + bool use_token_latencies = false; switch (in->dataset_type) { case ::mlperf::mobile::DatasetConfig::IMAGENET: dataset = std::make_unique<::mlperf::mobile::Imagenet>( @@ -116,6 +117,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( sp_path = in->backend_model_path; sp_path += '/' + sp_path_filename; LOG(INFO) << "SP path: " << sp_path; + use_token_latencies = true; dataset = std::make_unique<::mlperf::mobile::MmluGen>( backend.get(), in->dataset_data_path, sp_path, true /*zero-shot*/); break; @@ -138,7 +140,7 @@ struct dart_ffi_run_benchmark_out* dart_ffi_run_benchmark( auto start = std::chrono::steady_clock::now(); driver.RunMLPerfTest(in->mode, in->min_query_count, in->min_duration, in->max_duration, in->single_stream_expected_latency_ns, - in->output_dir); + in->output_dir, use_token_latencies); auto end = std::chrono::steady_clock::now(); li; From cfc719b4476fb41d56d8ed1bfdb76c1dcf9822ea Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Wed, 29 Oct 2025 00:00:47 +0300 Subject: [PATCH 45/74] enable exceptions for IOS --- .bazelrc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 80fe14b76..7967bf329 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,12 +79,8 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc -build:ios --copt=-DEIGEN_EXCEPTIONS=0 -build:ios --copt=-DEIGEN_NO_EXCEPTIONS -build:ios --copt=-fno-exceptions -build:ios --cxxopt=-DEIGEN_EXCEPTIONS=0 -build:ios --cxxopt=-DEIGEN_NO_EXCEPTIONS -build:ios --cxxopt=-fno-exceptions +build:ios --copt=-fexceptions +build:ios --cxxopt=-fexceptions # Windows configs From f87ef860cc932f2d95feed7984f405b920db9f2d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Wed, 29 Oct 2025 00:34:18 +0300 Subject: [PATCH 46/74] disable FP16 AVX for x86 simulator --- .bazelrc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.bazelrc b/.bazelrc index 7967bf329..057b23aa0 100644 --- a/.bazelrc +++ b/.bazelrc @@ -81,6 +81,8 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc build:ios --copt=-fexceptions build:ios --cxxopt=-fexceptions +# disable avx512-fp16 for x86 simulator +build:ios --define=xnn_enable_avx512fp16=false # Windows configs From b639dec6cdea7ed3b3b65d3e3d9db29031bbdb54 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 03:47:09 +0300 Subject: [PATCH 47/74] attempt to enable exceptions for eigen --- .bazelrc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.bazelrc b/.bazelrc index 057b23aa0..8a41fd3c4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,6 +79,8 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc +build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions +build:ios --per_file_copt=external/local_xla/.*@-fexceptions build:ios --copt=-fexceptions build:ios --cxxopt=-fexceptions # disable avx512-fp16 for x86 simulator From 3919626587ca0aac981fb6c6235f41399262bbaa Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 04:03:38 +0300 Subject: [PATCH 48/74] 2nd attempt at enabling exceptions for IOS eigen --- .bazelrc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index 8a41fd3c4..59180266e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,8 +79,9 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc -build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions -build:ios --per_file_copt=external/local_xla/.*@-fexceptions +build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions@-fcxx-exceptions +build:ios --per_file_copt=external/local_xla/.*@-fexceptions@-fcxx-exceptions +build:ios --per_file_copt=external/local_tsl/.*@-fexceptions@-fcxx-exceptions build:ios --copt=-fexceptions build:ios --cxxopt=-fexceptions # disable avx512-fp16 for x86 simulator From bb31d354ef912aa4bc859edbf55c6c6880d74634 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 04:23:19 +0300 Subject: [PATCH 49/74] fixed fexceptions syntax --- .bazelrc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index 59180266e..550473d66 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,9 +79,13 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc -build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions@-fcxx-exceptions -build:ios --per_file_copt=external/local_xla/.*@-fexceptions@-fcxx-exceptions -build:ios --per_file_copt=external/local_tsl/.*@-fexceptions@-fcxx-exceptions +# Ensure C++ exceptions ON for Eigen + TF/TSL/XLA files that include Eigen. +build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions +build:ios --per_file_copt=external/eigen_archive/.*@-fcxx-exceptions +build:ios --per_file_copt=external/local_xla/.*@-fexceptions +build:ios --per_file_copt=external/local_xla/.*@-fcxx-exceptions +build:ios --per_file_copt=external/local_tsl/.*@-fexceptions +build:ios --per_file_copt=external/local_tsl/.*@-fcxx-exceptions build:ios --copt=-fexceptions build:ios --cxxopt=-fexceptions # disable avx512-fp16 for x86 simulator From d07172ee1e7cc57f938b3935bd862b23bf24906d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 04:51:29 +0300 Subject: [PATCH 50/74] kitchen-sink approach to enable exceptions for IOS --- .bazelrc | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.bazelrc b/.bazelrc index 550473d66..722c65ab5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -80,12 +80,8 @@ build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc # Ensure C++ exceptions ON for Eigen + TF/TSL/XLA files that include Eigen. -build:ios --per_file_copt=external/eigen_archive/.*@-fexceptions -build:ios --per_file_copt=external/eigen_archive/.*@-fcxx-exceptions -build:ios --per_file_copt=external/local_xla/.*@-fexceptions -build:ios --per_file_copt=external/local_xla/.*@-fcxx-exceptions -build:ios --per_file_copt=external/local_tsl/.*@-fexceptions -build:ios --per_file_copt=external/local_tsl/.*@-fcxx-exceptions +build:ios --per_file_copt=external/.*@-fexceptions +build:ios --per_file_copt=external/.*@-fcxx-exceptions build:ios --copt=-fexceptions build:ios --cxxopt=-fexceptions # disable avx512-fp16 for x86 simulator From fe90222f194e3ae25584903f3e6a93e01f8e90f2 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 05:44:09 +0300 Subject: [PATCH 51/74] attempt to undefine EIGEN_EXCEPTIONS for IOS --- .bazelrc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.bazelrc b/.bazelrc index 722c65ab5..5a59d800a 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,11 +79,8 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc -# Ensure C++ exceptions ON for Eigen + TF/TSL/XLA files that include Eigen. -build:ios --per_file_copt=external/.*@-fexceptions -build:ios --per_file_copt=external/.*@-fcxx-exceptions -build:ios --copt=-fexceptions -build:ios --cxxopt=-fexceptions +build:ios --copt=-UEIGEN_EXCEPTIONS +build:ios --cxxopt=-UEIGEN_EXCEPTIONS # disable avx512-fp16 for x86 simulator build:ios --define=xnn_enable_avx512fp16=false From 9eee64839d585f02b3b20a17cd45ab056f5055ea Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 06:17:04 +0300 Subject: [PATCH 52/74] add global ovverride to disable eigen exceptions --- .bazelrc | 4 ++-- flutter/third_party/tf-eigen.patch | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index 5a59d800a..80af0ec7e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -79,8 +79,8 @@ build:linux_x86_64 --define=xnn_enable_avxvnniint8=false build:ios --apple_platform_type=ios build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc -build:ios --copt=-UEIGEN_EXCEPTIONS -build:ios --cxxopt=-UEIGEN_EXCEPTIONS +build:ios --copt=-DEIGEN_NOEXCEPTIONS_OVERRIDE +build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE # disable avx512-fp16 for x86 simulator build:ios --define=xnn_enable_avx512fp16=false diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index 12c6d3eef..8c84d2ff1 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -24,7 +24,7 @@ index 00000000000..6c94319de42 + + +-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) -++#if (defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) +++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) + #define EIGEN_EXCEPTIONS + #endif + From d9466a84f8c327e9b1c1e198ee3ca7b1f352e3e8 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 06:41:51 +0300 Subject: [PATCH 53/74] use ARM based macos for IOS build --- .github/workflows/ios-build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ios-build-test.yml b/.github/workflows/ios-build-test.yml index f91fe82b9..58c6c26fc 100644 --- a/.github/workflows/ios-build-test.yml +++ b/.github/workflows/ios-build-test.yml @@ -10,7 +10,7 @@ jobs: build: name: Build and test iOS app # https://github.com/actions/runner-images/blob/main/images/macos/macos-12-Readme.md - runs-on: macos-13 + runs-on: macos-14 timeout-minutes: 180 env: PERF_TEST: true From 6c1322ed1cefba1d08363acfe278e2d914af1164 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 07:11:05 +0300 Subject: [PATCH 54/74] fixed and re-enabled eigen patch --- WORKSPACE | 2 +- flutter/third_party/tf-eigen.patch | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index b082a1c23..d422dddb5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -58,7 +58,7 @@ http_archive( "//:flutter/third_party/use_unsigned_char.patch", # Fix tensorflow not being able to read image files on Windows "//:flutter/third_party/tensorflow-fix-file-opening-mode-for-Windows.patch", - #"//:flutter/third_party/tf-eigen.patch", + "//:flutter/third_party/tf-eigen.patch", ] + PATCH_FILE, sha256 = "d7876f4bb0235cac60eb6316392a7c48676729860da1ab659fb440379ad5186d", strip_prefix = "tensorflow-2.18.0", diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index 8c84d2ff1..d50bc3017 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -19,7 +19,7 @@ index 00000000000..6c94319de42 +index e76ddd3d2..6b4fc84ec 100644 +--- a/Eigen/src/Core/util/Macros.h ++++ b/Eigen/src/Core/util/Macros.h -+@@ -1281,7 +1281,7 @@ namespace Eigen { ++@@ -1236,7 +1236,7 @@ namespace Eigen { + EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) + + From 5667a6a184f3814bf7334d8f01729bfe0cbf378a Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 07:24:29 +0300 Subject: [PATCH 55/74] further fix for eigen patch --- flutter/third_party/tf-eigen.patch | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index d50bc3017..97b0ffae1 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -19,13 +19,13 @@ index 00000000000..6c94319de42 +index e76ddd3d2..6b4fc84ec 100644 +--- a/Eigen/src/Core/util/Macros.h ++++ b/Eigen/src/Core/util/Macros.h -+@@ -1236,7 +1236,7 @@ namespace Eigen { ++@@ -1236,7 +1236,6 @@ namespace Eigen { + EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) -+ -+ -+-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) ++ ++-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && \ ++- !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) ++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) -+ #define EIGEN_EXCEPTIONS ++ #define EIGEN_EXCEPTIONS + #endif + diff --git a/third_party/eigen3/workspace.bzl b/third_party/eigen3/workspace.bzl From b927a505ef9ea3d6dda10799e70597943a87746e Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 07:44:33 +0300 Subject: [PATCH 56/74] even more patch fixing --- flutter/third_party/tf-eigen.patch | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index 97b0ffae1..dca5e906b 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -15,19 +15,18 @@ index 00000000000..6c94319de42 --- /dev/null +++ b/third_party/eigen3/eigen_ios.patch @@ -0,0 +1,13 @@ -+diff --git a/Eigen/src/Core/util/Macros.h b/Eigen/src/Core/util/Macros.h -+index e76ddd3d2..6b4fc84ec 100644 +--- a/Eigen/src/Core/util/Macros.h ++++ b/Eigen/src/Core/util/Macros.h -+@@ -1236,7 +1236,6 @@ namespace Eigen { -+ EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD,OPNAME) -+ ++@@ -1235,8 +1235,7 @@ ++ EIGEN_MAKE_SCALAR_BINARY_OP_ONTHELEFT(METHOD, OPNAME) \ ++ EIGEN_MAKE_SCALAR_BINARY_OP_ONTHERIGHT(METHOD, OPNAME) ++ +-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && \ +- !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) -++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) +++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !+defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) + #define EIGEN_EXCEPTIONS + #endif -+ ++ diff --git a/third_party/eigen3/workspace.bzl b/third_party/eigen3/workspace.bzl index 9782907cf5e..ad302dbabc5 100644 --- a/third_party/eigen3/workspace.bzl From a828de78acf981777dfaec190093cf1627efd9a9 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 07:52:30 +0300 Subject: [PATCH 57/74] fixed typo in eigen patch --- flutter/third_party/tf-eigen.patch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index dca5e906b..9571c8ce0 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -23,7 +23,7 @@ index 00000000000..6c94319de42 + +-#if (defined(_CPPUNWIND) || defined(__EXCEPTIONS)) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && \ +- !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE) -++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !+defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) +++#if !defined(EIGEN_NOEXCEPTIONS_OVERRIDE) && ((defined(_CPPUNWIND) || (defined(__EXCEPTIONS) && defined(__exceptions__))) && !defined(EIGEN_CUDA_ARCH) && !defined(EIGEN_EXCEPTIONS) && !defined(EIGEN_USE_SYCL) && !defined(EIGEN_HIP_DEVICE_COMPILE)) + #define EIGEN_EXCEPTIONS + #endif + From 867534ebc4529daab6c1b66f1fe903c0e0e77a70 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 07:54:19 +0300 Subject: [PATCH 58/74] fixed incorrect count in eigen patch --- flutter/third_party/tf-eigen.patch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flutter/third_party/tf-eigen.patch b/flutter/third_party/tf-eigen.patch index 9571c8ce0..35658c039 100644 --- a/flutter/third_party/tf-eigen.patch +++ b/flutter/third_party/tf-eigen.patch @@ -14,7 +14,7 @@ new file mode 100644 index 00000000000..6c94319de42 --- /dev/null +++ b/third_party/eigen3/eigen_ios.patch -@@ -0,0 +1,13 @@ +@@ -0,0 +1,12 @@ +--- a/Eigen/src/Core/util/Macros.h ++++ b/Eigen/src/Core/util/Macros.h +@@ -1235,8 +1235,7 @@ From 810d26054e5af5322a1b3de910131a3bd0b66579 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 08:28:23 +0300 Subject: [PATCH 59/74] force arm64 ios build --- .bazelrc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.bazelrc b/.bazelrc index 80af0ec7e..d4ff0b8b4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -83,6 +83,8 @@ build:ios --copt=-DEIGEN_NOEXCEPTIONS_OVERRIDE build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE # disable avx512-fp16 for x86 simulator build:ios --define=xnn_enable_avx512fp16=false +# force arm64 simulator +build:ios --ios_multi_cpus=sim_arm64 # Windows configs From 8878a509a03a1fb9aa95499651acec4aa54936fc Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 09:49:27 +0300 Subject: [PATCH 60/74] use ARM64 simulator for IOS build --- flutter/cpp/flutter/BUILD | 4 ++-- mobile_back_apple/cpp/backend_coreml/BUILD | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flutter/cpp/flutter/BUILD b/flutter/cpp/flutter/BUILD index d6bc26a16..a68811070 100644 --- a/flutter/cpp/flutter/BUILD +++ b/flutter/cpp/flutter/BUILD @@ -48,9 +48,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - "x86_64", + #"x86_64", # cpuinfo does not support simulator on ARM-based macs - # "ios_sim_arm64", + "ios_sim_arm64", ], "device": ["arm64"], }, diff --git a/mobile_back_apple/cpp/backend_coreml/BUILD b/mobile_back_apple/cpp/backend_coreml/BUILD index 4e8acf07d..a3f2b3099 100644 --- a/mobile_back_apple/cpp/backend_coreml/BUILD +++ b/mobile_back_apple/cpp/backend_coreml/BUILD @@ -31,9 +31,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - "x86_64", + #"x86_64", # cpuinfo doesn't support simulator on ARM-based macs - # "ios_sim_arm64", + "ios_sim_arm64", ], "device": ["arm64"], }, From 744259dc681182dd5b4a7dd6a721c7026ba7a6c7 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 09:50:52 +0300 Subject: [PATCH 61/74] use arm64 simulator for tflite on IOS --- mobile_back_tflite/cpp/backend_tflite/ios/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD index 74fa88aea..7a3d8b350 100644 --- a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD @@ -7,9 +7,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - "x86_64", + #"x86_64", # cpuinfo doesn't support simulator on ARM-based macs - # "ios_sim_arm64", + "ios_sim_arm64", ], "device": ["arm64"], }, From 94cf6d64b8be9f431d920d7f6674d84a3ccb577f Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 10:02:31 +0300 Subject: [PATCH 62/74] set ios cpu argument for cpuinfo --- .bazelrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.bazelrc b/.bazelrc index d4ff0b8b4..d6cef30eb 100644 --- a/.bazelrc +++ b/.bazelrc @@ -85,6 +85,7 @@ build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE build:ios --define=xnn_enable_avx512fp16=false # force arm64 simulator build:ios --ios_multi_cpus=sim_arm64 +build:ios --cpu=ios_sim_arm64 # Windows configs From 85c8b2d62c50bbaf139f31b3221d2b24517494d9 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 17:38:13 +0300 Subject: [PATCH 63/74] remvoed ios_sim prefix --- flutter/cpp/flutter/BUILD | 2 +- mobile_back_apple/cpp/backend_coreml/BUILD | 2 +- mobile_back_tflite/cpp/backend_tflite/ios/BUILD | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flutter/cpp/flutter/BUILD b/flutter/cpp/flutter/BUILD index a68811070..01be3ec10 100644 --- a/flutter/cpp/flutter/BUILD +++ b/flutter/cpp/flutter/BUILD @@ -50,7 +50,7 @@ apple_xcframework( "simulator": [ #"x86_64", # cpuinfo does not support simulator on ARM-based macs - "ios_sim_arm64", + "arm64", ], "device": ["arm64"], }, diff --git a/mobile_back_apple/cpp/backend_coreml/BUILD b/mobile_back_apple/cpp/backend_coreml/BUILD index a3f2b3099..9d5cbbf4a 100644 --- a/mobile_back_apple/cpp/backend_coreml/BUILD +++ b/mobile_back_apple/cpp/backend_coreml/BUILD @@ -33,7 +33,7 @@ apple_xcframework( "simulator": [ #"x86_64", # cpuinfo doesn't support simulator on ARM-based macs - "ios_sim_arm64", + "arm64", ], "device": ["arm64"], }, diff --git a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD index 7a3d8b350..fe85022b9 100644 --- a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD @@ -9,7 +9,7 @@ apple_xcframework( "simulator": [ #"x86_64", # cpuinfo doesn't support simulator on ARM-based macs - "ios_sim_arm64", + "arm64", ], "device": ["arm64"], }, From bffdcd51290e325595d04d6e1ba7c3305359a208 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 19:24:53 +0300 Subject: [PATCH 64/74] attempt at using arm64 simulator for IOS instead of x86 --- .bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index d6cef30eb..706e62bb6 100644 --- a/.bazelrc +++ b/.bazelrc @@ -84,7 +84,7 @@ build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE # disable avx512-fp16 for x86 simulator build:ios --define=xnn_enable_avx512fp16=false # force arm64 simulator -build:ios --ios_multi_cpus=sim_arm64 +build:ios --ios_multi_cpus=arm64 build:ios --cpu=ios_sim_arm64 # Windows configs From 7b7f30d1fdd6597a5563f938f931d716413562fe Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Thu, 30 Oct 2025 23:20:17 +0300 Subject: [PATCH 65/74] attempt to force flutter to build ITs for arm64 only --- .github/workflows/ios-build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ios-build-test.yml b/.github/workflows/ios-build-test.yml index 41d091ba1..38caf8147 100644 --- a/.github/workflows/ios-build-test.yml +++ b/.github/workflows/ios-build-test.yml @@ -61,7 +61,7 @@ jobs: make flutter/test/unit - name: Build iOS integration tests run: | - cd flutter && flutter --no-version-check build ios --simulator integration_test/first_test.dart + cd flutter && arch -arm64 flutter --no-version-check build ios --simulator integration_test/first_test.dart - name: Setup iOS simulator env: DEVICE_NAME: "iPhone 16 Pro" From f755765652b85433c17e264e826a7c336f0c9efb Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 00:22:24 +0300 Subject: [PATCH 66/74] force arm64 for pods --- flutter/ios/Podfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flutter/ios/Podfile b/flutter/ios/Podfile index 4920ff15b..d7569d074 100644 --- a/flutter/ios/Podfile +++ b/flutter/ios/Podfile @@ -67,7 +67,8 @@ post_install do |installer| end installer.pods_project.targets.each do |target| target.build_configurations.each do |config| - config.build_settings['EXCLUDED_ARCHS[sdk=iphonesimulator*]'] = 'arm64' + config.build_settings['EXCLUDED_ARCHS[sdk=iphonesimulator*]'] = '' + config.build_settings['ARCHS'] = 'arm64' end end end From 1e7d5a088ce24e805284b00fc9da0e05bc9190cd Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 01:29:19 +0300 Subject: [PATCH 67/74] disable f16 instead of building for arm64 --- .bazelrc | 5 ++--- .github/workflows/ios-build-test.yml | 2 +- flutter/cpp/flutter/BUILD | 4 ++-- flutter/ios/Podfile | 3 +-- mobile_back_apple/cpp/backend_coreml/BUILD | 4 ++-- mobile_back_tflite/cpp/backend_tflite/ios/BUILD | 4 ++-- 6 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.bazelrc b/.bazelrc index 706e62bb6..ac54c7758 100644 --- a/.bazelrc +++ b/.bazelrc @@ -82,10 +82,9 @@ build:ios --cxxopt=-fobjc-arc build:ios --copt=-DEIGEN_NOEXCEPTIONS_OVERRIDE build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE # disable avx512-fp16 for x86 simulator +build:ios --define=xnn_enable_f16=false +build:ios --define=xnn_enable_f16c=false build:ios --define=xnn_enable_avx512fp16=false -# force arm64 simulator -build:ios --ios_multi_cpus=arm64 -build:ios --cpu=ios_sim_arm64 # Windows configs diff --git a/.github/workflows/ios-build-test.yml b/.github/workflows/ios-build-test.yml index 38caf8147..41d091ba1 100644 --- a/.github/workflows/ios-build-test.yml +++ b/.github/workflows/ios-build-test.yml @@ -61,7 +61,7 @@ jobs: make flutter/test/unit - name: Build iOS integration tests run: | - cd flutter && arch -arm64 flutter --no-version-check build ios --simulator integration_test/first_test.dart + cd flutter && flutter --no-version-check build ios --simulator integration_test/first_test.dart - name: Setup iOS simulator env: DEVICE_NAME: "iPhone 16 Pro" diff --git a/flutter/cpp/flutter/BUILD b/flutter/cpp/flutter/BUILD index 01be3ec10..19b9616b2 100644 --- a/flutter/cpp/flutter/BUILD +++ b/flutter/cpp/flutter/BUILD @@ -48,9 +48,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - #"x86_64", + "x86_64", # cpuinfo does not support simulator on ARM-based macs - "arm64", + # "arm64", ], "device": ["arm64"], }, diff --git a/flutter/ios/Podfile b/flutter/ios/Podfile index d7569d074..4920ff15b 100644 --- a/flutter/ios/Podfile +++ b/flutter/ios/Podfile @@ -67,8 +67,7 @@ post_install do |installer| end installer.pods_project.targets.each do |target| target.build_configurations.each do |config| - config.build_settings['EXCLUDED_ARCHS[sdk=iphonesimulator*]'] = '' - config.build_settings['ARCHS'] = 'arm64' + config.build_settings['EXCLUDED_ARCHS[sdk=iphonesimulator*]'] = 'arm64' end end end diff --git a/mobile_back_apple/cpp/backend_coreml/BUILD b/mobile_back_apple/cpp/backend_coreml/BUILD index 9d5cbbf4a..2caebeaa3 100644 --- a/mobile_back_apple/cpp/backend_coreml/BUILD +++ b/mobile_back_apple/cpp/backend_coreml/BUILD @@ -31,9 +31,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - #"x86_64", + "x86_64", # cpuinfo doesn't support simulator on ARM-based macs - "arm64", + # "arm64", ], "device": ["arm64"], }, diff --git a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD index fe85022b9..0e7eb1452 100644 --- a/mobile_back_tflite/cpp/backend_tflite/ios/BUILD +++ b/mobile_back_tflite/cpp/backend_tflite/ios/BUILD @@ -7,9 +7,9 @@ apple_xcframework( infoplists = ["//flutter/cpp/flutter:BackendBridgeInfo.plist"], ios = { "simulator": [ - #"x86_64", + "x86_64", # cpuinfo doesn't support simulator on ARM-based macs - "arm64", + # "arm64", ], "device": ["arm64"], }, From 0482013f02ebe5fc0a55abcef0bf63c844fa600d Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 02:19:46 +0300 Subject: [PATCH 68/74] more bazel config lines to disable fp16 --- .bazelrc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index ac54c7758..c64957df3 100644 --- a/.bazelrc +++ b/.bazelrc @@ -81,10 +81,16 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc build:ios --copt=-DEIGEN_NOEXCEPTIONS_OVERRIDE build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE -# disable avx512-fp16 for x86 simulator +# TODO these should only be enabled for x86 simulator build:ios --define=xnn_enable_f16=false build:ios --define=xnn_enable_f16c=false build:ios --define=xnn_enable_avx512fp16=false +build:ios --copt=-DFP16_DISABLE=1 +build:ios --copt=-DPSIMD_DISABLE_F16=1 +build:ios --copt=-mno-f16c +build:ios --cxxopt=-DFP16_DISABLE=1 +build:ios --cxxopt=-DPSIMD_DISABLE_F16=1 +build:ios --cxxopt=-mno-f16c # Windows configs From 5dd0383d7803f692fc8cb139409a0f5dfcf64f6e Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 03:46:55 +0300 Subject: [PATCH 69/74] removed unavailable compiler flags --- .bazelrc | 2 -- 1 file changed, 2 deletions(-) diff --git a/.bazelrc b/.bazelrc index c64957df3..6f7200c0f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -87,10 +87,8 @@ build:ios --define=xnn_enable_f16c=false build:ios --define=xnn_enable_avx512fp16=false build:ios --copt=-DFP16_DISABLE=1 build:ios --copt=-DPSIMD_DISABLE_F16=1 -build:ios --copt=-mno-f16c build:ios --cxxopt=-DFP16_DISABLE=1 build:ios --cxxopt=-DPSIMD_DISABLE_F16=1 -build:ios --cxxopt=-mno-f16c # Windows configs From 142f3061256f4d433396e94f1f27042297315b59 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 07:22:26 +0300 Subject: [PATCH 70/74] provide patched fp16 lib with math workaround --- .bazelrc | 8 ------ WORKSPACE | 11 +++++++ patches/fp16_math_workaround.patch | 46 ++++++++++++++++++++++++++++++ third_party/FP16.BUILD | 15 ++++++++++ 4 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 patches/fp16_math_workaround.patch create mode 100644 third_party/FP16.BUILD diff --git a/.bazelrc b/.bazelrc index 6f7200c0f..bab46822c 100644 --- a/.bazelrc +++ b/.bazelrc @@ -81,14 +81,6 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --cxxopt=-fobjc-arc build:ios --copt=-DEIGEN_NOEXCEPTIONS_OVERRIDE build:ios --cxxopt=-DEIGEN_NOEXCEPTIONS_OVERRIDE -# TODO these should only be enabled for x86 simulator -build:ios --define=xnn_enable_f16=false -build:ios --define=xnn_enable_f16c=false -build:ios --define=xnn_enable_avx512fp16=false -build:ios --copt=-DFP16_DISABLE=1 -build:ios --copt=-DPSIMD_DISABLE_F16=1 -build:ios --cxxopt=-DFP16_DISABLE=1 -build:ios --cxxopt=-DPSIMD_DISABLE_F16=1 # Windows configs diff --git a/WORKSPACE b/WORKSPACE index d422dddb5..01a5ff312 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -48,6 +48,17 @@ http_archive( load("@tf_patch_finder//:patch_win_arm64.bzl", "PATCH_FILE") +http_archive( + name = "FP16", + build_file = "@//third_party:FP16.BUILD", + patches = ["//patches:fp16_math_workaround.patch.diff"], + sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", + strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", + urls = [ + "https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip", + ], +) + http_archive( name = "org_tensorflow", patch_args = ["-p1"], diff --git a/patches/fp16_math_workaround.patch b/patches/fp16_math_workaround.patch new file mode 100644 index 000000000..3a50f6be8 --- /dev/null +++ b/patches/fp16_math_workaround.patch @@ -0,0 +1,46 @@ +From 6bc45b3b372bfc67a514d0bda24993e983f79aa8 Mon Sep 17 00:00:00 2001 +From: Dillon +Date: Sat, 31 Aug 2024 13:18:10 -0700 +Subject: [PATCH] Remove dependency on math.h. This of course should be fine, + but there are some misconfigured cross compilation builds that have issues + with math.h. These builds should be fixed, but I also just generally like to + avoid dependencies where possible, which it is in this case. + +--- + include/fp16/fp16.h | 9 ++++----- + 1 file changed, 4 insertions(+), 5 deletions(-) + +diff --git a/include/fp16/fp16.h b/include/fp16/fp16.h +index 95fa0c2..142ef6a 100644 +--- a/include/fp16/fp16.h ++++ b/include/fp16/fp16.h +@@ -4,10 +4,8 @@ + + #if defined(__cplusplus) && (__cplusplus >= 201103L) + #include +- #include + #elif !defined(__OPENCL_VERSION__) + #include +- #include + #endif + + #include +@@ -286,14 +284,15 @@ static inline uint16_t fp16_ieee_from_fp32_value(float f) { + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); + #endif ++ const uint32_t w = fp32_to_bits(f); ++ const float abs_f = fp32_from_bits(w & UINT32_C(0x7FFFFFFF)); + #if defined(_MSC_VER) && defined(_M_IX86_FP) && (_M_IX86_FP == 0) || defined(__GNUC__) && defined(__FLT_EVAL_METHOD__) && (__FLT_EVAL_METHOD__ != 0) +- const volatile float saturated_f = fabsf(f) * scale_to_inf; ++ const volatile float saturated_f = abs_f * scale_to_inf; + #else +- const float saturated_f = fabsf(f) * scale_to_inf; ++ const float saturated_f = abs_f * scale_to_inf; + #endif + float base = saturated_f * scale_to_zero; + +- const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); diff --git a/third_party/FP16.BUILD b/third_party/FP16.BUILD new file mode 100644 index 000000000..e1018beb4 --- /dev/null +++ b/third_party/FP16.BUILD @@ -0,0 +1,15 @@ +# Description: +# C/C++ library for conversion to/from half-precision floating-point formats + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "FP16", + hdrs = glob(["include/**/*.h"]), + includes = ["include"], + strip_include_prefix = "include", +) From af80d6781ad8cbe979956ec02de8e953a75f3e2e Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 07:31:37 +0300 Subject: [PATCH 71/74] typo --- WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/WORKSPACE b/WORKSPACE index 01a5ff312..996c39ca7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -51,7 +51,7 @@ load("@tf_patch_finder//:patch_win_arm64.bzl", "PATCH_FILE") http_archive( name = "FP16", build_file = "@//third_party:FP16.BUILD", - patches = ["//patches:fp16_math_workaround.patch.diff"], + patches = ["//patches:fp16_math_workaround.patch"], sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", urls = [ From 674ffde84fcc172d42f4d6ee0939271ed25e452e Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 07:44:58 +0300 Subject: [PATCH 72/74] added patch arg --- WORKSPACE | 1 + 1 file changed, 1 insertion(+) diff --git a/WORKSPACE b/WORKSPACE index 996c39ca7..68a71d32d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -51,6 +51,7 @@ load("@tf_patch_finder//:patch_win_arm64.bzl", "PATCH_FILE") http_archive( name = "FP16", build_file = "@//third_party:FP16.BUILD", + patch_args = ["-p1"], patches = ["//patches:fp16_math_workaround.patch"], sha256 = "e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70", strip_prefix = "FP16-0a92994d729ff76a58f692d3028ca1b64b145d91", From 36f46024d6273437c69feae8c5831a8a9e714104 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 08:08:57 +0300 Subject: [PATCH 73/74] created a math workaround patch compatible with fp16 version used by xnnpack --- patches/fp16_math_workaround.patch | 34 ++++++++++-------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/patches/fp16_math_workaround.patch b/patches/fp16_math_workaround.patch index 3a50f6be8..d44fb03ec 100644 --- a/patches/fp16_math_workaround.patch +++ b/patches/fp16_math_workaround.patch @@ -1,19 +1,13 @@ -From 6bc45b3b372bfc67a514d0bda24993e983f79aa8 Mon Sep 17 00:00:00 2001 -From: Dillon -Date: Sat, 31 Aug 2024 13:18:10 -0700 +Based on 6bc45b3b372bfc67a514d0bda24993e983f79aa8 Mon Sep 17 00:00:00 2001 +From: Farook Al-Sammarraie +Date: Fri, 31 Oct 2025 Subject: [PATCH] Remove dependency on math.h. This of course should be fine, but there are some misconfigured cross compilation builds that have issues with math.h. These builds should be fixed, but I also just generally like to avoid dependencies where possible, which it is in this case. ---- - include/fp16/fp16.h | 9 ++++----- - 1 file changed, 4 insertions(+), 5 deletions(-) - -diff --git a/include/fp16/fp16.h b/include/fp16/fp16.h -index 95fa0c2..142ef6a 100644 --- a/include/fp16/fp16.h -+++ b/include/fp16/fp16.h ++++ a/include/fp16/fp16.h @@ -4,10 +4,8 @@ #if defined(__cplusplus) && (__cplusplus >= 201103L) @@ -24,23 +18,17 @@ index 95fa0c2..142ef6a 100644 - #include #endif - #include -@@ -286,14 +284,15 @@ static inline uint16_t fp16_ieee_from_fp32_value(float f) { + #ifdef _MSC_VER +@@ -228,9 +226,10 @@ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); #endif -+ const uint32_t w = fp32_to_bits(f); +- float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +- + const uint32_t w = fp32_to_bits(f); + const float abs_f = fp32_from_bits(w & UINT32_C(0x7FFFFFFF)); - #if defined(_MSC_VER) && defined(_M_IX86_FP) && (_M_IX86_FP == 0) || defined(__GNUC__) && defined(__FLT_EVAL_METHOD__) && (__FLT_EVAL_METHOD__ != 0) -- const volatile float saturated_f = fabsf(f) * scale_to_inf; -+ const volatile float saturated_f = abs_f * scale_to_inf; - #else -- const float saturated_f = fabsf(f) * scale_to_inf; -+ const float saturated_f = abs_f * scale_to_inf; - #endif - float base = saturated_f * scale_to_zero; - -- const uint32_t w = fp32_to_bits(f); ++ float base = (abs_f * scale_to_inf) * scale_to_zero; ++ const uint32_t shl1_w = w + w; const uint32_t sign = w & UINT32_C(0x80000000); uint32_t bias = shl1_w & UINT32_C(0xFF000000); From 64066a478b51e7cb70f9a0017036e27d98af4882 Mon Sep 17 00:00:00 2001 From: Farook Al-Sammarraie Date: Fri, 31 Oct 2025 10:14:50 +0300 Subject: [PATCH 74/74] datasets now provide token limits as inputs to pipeline --- flutter/cpp/datasets/ifeval.cc | 1 + flutter/cpp/datasets/ifeval.h | 1 + flutter/cpp/datasets/mmlu_gen.cc | 1 + flutter/cpp/datasets/mmlu_gen.h | 1 + .../cpp/backend_tflite/llm_pipeline.cc | 14 +++++++++++--- 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/flutter/cpp/datasets/ifeval.cc b/flutter/cpp/datasets/ifeval.cc index 9328c2272..83f950623 100644 --- a/flutter/cpp/datasets/ifeval.cc +++ b/flutter/cpp/datasets/ifeval.cc @@ -62,6 +62,7 @@ std::vector IFEval::GetData(int sample_idx) { if (sample_idx < samples_.size()) { data.push_back(reinterpret_cast( const_cast*>(&(samples_[sample_idx]->input_tokens)))); + data.push_back(reinterpret_cast(const_cast(&token_limit_))); } return data; } diff --git a/flutter/cpp/datasets/ifeval.h b/flutter/cpp/datasets/ifeval.h index 2dd1e4747..7a3f15d21 100644 --- a/flutter/cpp/datasets/ifeval.h +++ b/flutter/cpp/datasets/ifeval.h @@ -81,6 +81,7 @@ class IFEval : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; + static constexpr int token_limit_ = 1024; }; } // namespace mobile diff --git a/flutter/cpp/datasets/mmlu_gen.cc b/flutter/cpp/datasets/mmlu_gen.cc index 8f999ba30..921d47714 100644 --- a/flutter/cpp/datasets/mmlu_gen.cc +++ b/flutter/cpp/datasets/mmlu_gen.cc @@ -83,6 +83,7 @@ std::vector MmluGen::GetData(int sample_idx) { if (sample_idx < samples_.size()) { data.push_back(reinterpret_cast( const_cast*>(&(samples_[sample_idx]->input_tokens)))); + data.push_back(reinterpret_cast(const_cast(&token_limit_))); } return data; } diff --git a/flutter/cpp/datasets/mmlu_gen.h b/flutter/cpp/datasets/mmlu_gen.h index ce8189418..72f6c6a71 100644 --- a/flutter/cpp/datasets/mmlu_gen.h +++ b/flutter/cpp/datasets/mmlu_gen.h @@ -66,6 +66,7 @@ class MmluGen : public Dataset { std::unordered_set used_sample_ids_; std::set loaded_sample_ids_; std::unique_ptr sp_processor; + static constexpr int token_limit_ = 4; }; } // namespace mobile diff --git a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc index ed6ea5d8e..33150efeb 100644 --- a/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc +++ b/mobile_back_tflite/cpp/backend_tflite/llm_pipeline.cc @@ -213,19 +213,22 @@ mlperf_status_t LLMPipeline::backend_flush_queries( } // Return the number of inputs of the model. -// Only 1 input need to be provided, the tokens themselves. -// The other inputs are handled by the pipeline +// 2 inputs need to be provided manually, the tokens themselves. and a token +// count limit. The other inputs are handled by the pipeline int32_t LLMPipeline::backend_get_input_count(mlperf_backend_ptr_t backend_ptr) { - return 1; + return 2; } // Return the type of the ith input. +// All inputs are of they type [int32] mlperf_data_t LLMPipeline::backend_get_input_type( mlperf_backend_ptr_t backend_ptr, int32_t i) { return mlperf_data_t{mlperf_data_t::Int32, 0}; } // Set the data for ith input. +// 0: list of input tokens. +// 1: output token count limit. mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, int32_t batch_index, int32_t i, void* data) { @@ -233,6 +236,11 @@ mlperf_status_t LLMPipeline::backend_set_input(mlperf_backend_ptr_t backend_ptr, // Reset the tokens and kv caches from potential previous runs. backend_data->output_tokens.clear(); + if (i == 1) { + backend_data->max_output_tokens = *(reinterpret_cast(data)); + return MLPERF_SUCCESS; + } + for (auto& [_, vec] : backend_data->kv_cache) { std::fill(vec.begin(), vec.end(), 0.0f); }