diff --git a/README.md b/README.md index d7ce69a..8b186a1 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,422 @@ -# alphageometry -TODO(b/303301137): Add a description for your new project, explain what is -being released here, etc... Additional, the following sections are normally -expected for all releases. Feel free to add additional sections if appropriate -for your project. +# Solving Olympiad Geometry without Human Demonstrations -## Installation -Write instructions for how the user should install your code. The instructions -should ideally be valid when copy-pasted. You can combine this with the Usage -section if there's no separate installation step. +This repository contains the code necessary to +reproduce DDAR and AlphaGeometry, +the two geometry theorem provers +introduced in the [Nature 2024](https://www.nature.com/) paper: + +*
"Solving Olympiad Geometry without Human Demonstrations".
* + + +
+ + +
+*(figure: fig1)* +
+ + +## Dependencies + +For the instructions presented below, +we use Python 3.10.9, with dependencies pinned to the exact +version numbers listed in `requirements.txt`. + +Our code depends on `meliad`, which is +not a registered package with `pip`. See instructions below +for how to manually install `meliad`. + +Note that one can still run the DDAR solver +without the `meliad` and `sentencepiece` dependencies. + +## Run the instructions + +All instructions in this `README.md` can be run in one go by: + +``` +bash run.sh +``` + +Below, we explain these instructions step-by-step. + +## Install dependencies, download weights and vocabulary. + +Installation is done in a virtual environment: + +``` +virtualenv -p python3 . +source ./bin/activate +pip install --require-hashes -r requirements.txt +``` + +Download weights and vocabulary: + +``` +bash download.sh +DATA=ag_ckpt_vocab +``` + +Finally, install `meliad` separately, as it is not +registered with `pip`: + +``` +MELIAD_PATH=meliad_lib/meliad +mkdir -p $MELIAD_PATH +git clone https://github.com/google-research/meliad $MELIAD_PATH +export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH +``` + +## Set up common flags + +Before running the Python scripts, +let us first prepare some commonly used flags. +The symbolic engine needs definitions and deduction rules to operate. +These definitions and rules are provided in the two text files +`defs.txt` and `rules.txt`. + +```shell +DDAR_ARGS=( + --defs_file=$(pwd)/defs.txt \ + --rules_file=$(pwd)/rules.txt \ +); +``` + +Next, we define the flags relevant to the proof search. +To reproduce the simple examples below, +we use lightweight values for the proof search parameters: + +```shell +BATCH_SIZE=2 +BEAM_SIZE=2 +DEPTH=2 + +SEARCH_ARGS=( + --beam_size=$BEAM_SIZE + --search_depth=$DEPTH +) +``` + +NOTE: The results in our paper can be obtained by setting +`BATCH_SIZE=32`, `BEAM_SIZE=512`, `DEPTH=16` +as described in the Methods section. 
+To stay under IMO time limits, four V100 GPUs and 250 CPU workers +are needed, as shown in Extended Data - Figure 1. +Note that we also strip away other memory/speed optimizations +due to internal dependencies and to promote code clarity. + +Assume the downloaded checkpoint and vocabulary are placed in `DATA`, +and the installed `meliad` source code is at `MELIAD_PATH`. +We make use of the `gin` library to manage model configurations, +following `meliad` conventions. We now define the flags relevant to the +language model: + +```shell +LM_ARGS=( + --ckpt_path=$DATA \ + --vocab_path=$DATA/geometry.757.model \ + --gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \ + --gin_file=base_htrans.gin \ + --gin_file=size/medium_150M.gin \ + --gin_file=options/positions_t5.gin \ + --gin_file=options/lr_cosine_decay.gin \ + --gin_file=options/seq_1024_nocache.gin \ + --gin_file=geometry_150M_generate.gin \ + --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \ + --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \ + --gin_param=TransformerTaskConfig.sequence_length=128 \ + --gin_param=Trainer.restore_state_variables=False +); +``` + +TIP: Note that you can still run the DDAR solver +without defining `SEARCH_ARGS` and `LM_ARGS`. +In that case, simply disable the import of the `lm_inference` module +inside `alphageometry.py`. + +## Run DDAR + +The script loads a problem by reading a list of problems +from a text file and solves the specific problem in the list according +to its name. We pass these two pieces of information through the flags +`--problems_file` and `--problem_name`. +We use `--mode=ddar` to indicate that we want to use the DDAR solver. 
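For reference, `--problems_file` points to a plain-text file and `--problem_name` selects one entry in it. As an illustration only (the helper below is ours, not part of this repository), a file in the two-line-per-problem layout of `imo_ag_30.txt` (a name line followed by a statement line) could be read like this:

```python
# Illustrative sketch, not the repository's parser (see problem.py).
# Assumes the two-line-per-problem layout: name line, then statement line.

def load_problems(text: str) -> dict[str, str]:
    """Map each problem name to its statement string."""
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    # Pair every name line with the statement line that follows it.
    return dict(zip(lines[0::2], lines[1::2]))


sample = """\
orthocenter
a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c
"""
problems = load_problems(sample)
```

The actual loading is done by `pr.Problem.from_txt_file` inside `alphageometry.py`.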
+ +Below, we show this solver solving IMO 2000 P1: + +```shell +python -m alphageometry \ +--alsologtostderr \ +--problems_file=$(pwd)/imo_ag_30.txt \ +--problem_name=translated_imo_2000_p1 \ +--mode=ddar \ +"${DDAR_ARGS[@]}" +``` + +Expect the following output: + +```shell +graph.py:468] translated_imo_2000_p1 +graph.py:469] a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q +ddar.py:41] Depth 1/1000 time = 1.7772269248962402 +ddar.py:41] Depth 2/1000 time = 5.63526177406311 +ddar.py:41] Depth 3/1000 time = 6.883412837982178 +ddar.py:41] Depth 4/1000 time = 10.275688409805298 +ddar.py:41] Depth 5/1000 time = 12.048273086547852 +alphageometry.py:190] +========================== + * From theorem premises: +A B G1 G2 M N C D E P Q : Points +AG_1 ⟂ AB [00] +BA ⟂ G_2B [01] +G_2M = G_2B [02] +G_1M = G_1A [03] + +... +[log omitted] +... + +036. ∠QEB = ∠(QP-EA) [46] & ∠(BE-QP) = ∠AEP [55] ⇒ ∠EQP = ∠QPE [56] +037. ∠PQE = ∠EPQ [56] ⇒ EP = EQ + +========================== +``` + +The output first includes a list of relevant premises that it uses, +and then proof steps that gradually build up the proof. +All predicates are numbered to track how they are derived +from the premises, and to show that the proof is fully justified. + +TIP: Additionally passing the flag `--out_file=path/to/output/text/file.txt` +will write the proof to a text file. + +Running on all problems in `imo_ag_30.txt` will yield solutions to +14 of them, as reported in Table 1 in our paper. + +## Run AlphaGeometry + +As a simple example, we load `--problem_name=orthocenter` +from `--problems_file=examples.txt`. 
+This time, we pass `--mode=alphageometry` to use the AlphaGeometry solver +and pass the `SEARCH_ARGS` and `LM_ARGS` flags. + +```shell +python -m alphageometry \ +--alsologtostderr \ +--problems_file=$(pwd)/examples.txt \ +--problem_name=orthocenter \ +--mode=alphageometry \ +"${DDAR_ARGS[@]}" \ +"${SEARCH_ARGS[@]}" \ +"${LM_ARGS[@]}" +``` + +Expect the following output: + +```shell +... +[log omitted] +... +training_loop.py:725] Total parameters: 152072288 +training_loop.py:739] Total state size: 0 +training_loop.py:492] Training loop: creating task for mode beam_search + +graph.py:468] orthocenter +graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c +ddar.py:41] Depth 1/1000 time = 0.009987592697143555 branch = 4 +ddar.py:41] Depth 2/1000 time = 0.00672602653503418 branch = 0 +alphageometry.py:221] DD+AR failed to solve the problem. +alphageometry.py:457] Depth 0. There are 1 nodes to expand: +alphageometry.py:460] {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00 +alphageometry.py:465] Decoding from {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00 +... +[log omitted] +... +alphageometry.py:470] LM output (score=-1.102287): "e : C a c e 02 C b d e 03 ;" +alphageometry.py:471] Translation: "e = on_line e a c, on_line e b d" + +alphageometry.py:480] Solving: "a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c" +graph.py:468] +graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? 
perp a d b c +ddar.py:41] Depth 1/1000 time = 0.021120786666870117 +ddar.py:41] Depth 2/1000 time = 0.033370018005371094 +ddar.py:41] Depth 3/1000 time = 0.04297471046447754 +alphageometry.py:140] +========================== + * From theorem premises: +A B C D : Points +BD ⟂ AC [00] +CD ⟂ AB [01] + + * Auxiliary Constructions: +E : Points +E,B,D are collinear [02] +E,C,A are collinear [03] + + * Proof steps: +001. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEA = ∠CED [04] +002. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEC = ∠AED [05] +003. A,E,C are collinear [03] & E,B,D are collinear [02] & AC ⟂ BD [00] ⇒ EC ⟂ EB [06] +004. EC ⟂ EB [06] & CD ⟂ AB [01] ⇒ ∠(EC-BA) = ∠(EB-CD) [07] +005. E,C,A are collinear [03] & E,B,D are collinear [02] & ∠(EC-BA) = ∠(EB-CD) [07] ⇒ ∠BAE = ∠CDE [08] +006. ∠BEA = ∠CED [04] & ∠BAE = ∠CDE [08] (Similar Triangles)⇒ EB:EC = EA:ED [09] +007. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠BCE = ∠ADE [10] +008. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠EBC = ∠EAD [11] +009. ∠BCE = ∠ADE [10] & E,C,A are collinear [03] & E,B,D are collinear [02] & ∠EBC = ∠EAD [11] ⇒ AD ⟂ BC +========================== + +alphageometry.py:505] Solved. +``` + +NOTE: Point `H` is automatically renamed to `D`, +as the LM is trained on synthetic problems +where the points are named alphabetically, and so it expects +the same during test time. + +NOTE: In this implementation of AlphaGeometry, +we removed all optimizations that are dependent on +internal infrastructure, e.g., +parallelized model inference on multi GPUs, +parallelized DDAR on multiple CPUs, +parallel execution of LM and DDAR, +shared pool of CPU workers across different problems, etc. +We also removed some memory/speed optimizations and code +abstractions in favor of code clarity. + +As can be seen in the output, initially DDAR failed to solve the problem. 
+The LM proposes two auxiliary constructions (because `BATCH_SIZE=2`): + +* `e = eqdistance e c a b, eqdistance e b a c`, i.e., +construct `E` as the intersection of circle (center=C, radius=AB) and +circle (center=B, radius=AC). This construction has a score of `-1.186`. +* `e = on_line e a c, on_line e b d`, i.e., +`E` is the intersection of `AC` and `BD`. +This construction has a higher score (`-1.102287`) than the previous one. + +Since the second construction has a higher score, the proof search attempted +it first, and DDAR found the solution right away. +The proof search therefore terminated without a second iteration. + +## Results + +Before attempting to reproduce the AlphaGeometry numbers in our paper, +please make sure to pass all tests in the prepared test suite: + +``` +bash run_tests.sh +``` + +Then, pass the corresponding values for `--problems_file` (column) +and `--mode` (row), and +iterate on all problems to obtain the following results: + +<center>
+ +Number of solved problems: + +| | `imo_ag_30.txt` | `jgex_ag_231.txt` | +|----------|------------------|-------------------| +| `ddar` | 14 | 198 | +| `alphageometry` | 25 | 228 | + +
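Reproducing these numbers amounts to invoking the solver once per problem name. As a hedged sketch (the `make_command` helper and the `--out_file` naming scheme are illustrative, not part of the repository; the flags mirror the commands shown earlier):

```python
# Illustrative: build one solver command line per problem.
# Flag names match those used in the commands above; `make_command`
# and the out_file naming are our own, not part of the repository.

def make_command(problems_file: str, problem_name: str, mode: str) -> list[str]:
    if mode not in ('ddar', 'alphageometry'):
        raise ValueError(f'unknown mode: {mode}')
    return [
        'python', '-m', 'alphageometry',
        '--alsologtostderr',
        f'--problems_file={problems_file}',
        f'--problem_name={problem_name}',
        f'--mode={mode}',
        f'--out_file={problem_name}.{mode}.txt',
    ]


cmd = make_command('imo_ag_30.txt', 'translated_imo_2000_p1', 'ddar')
```

Each such command can then be executed, e.g. with `subprocess.run(cmd, check=True)`, with `DDAR_ARGS`/`SEARCH_ARGS`/`LM_ARGS` appended as in the earlier examples.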
+ +## Source code description + +Files in this repository include Python modules/scripts to run the solvers and +resource files necessary for the scripts to execute. Each of them is listed +below with a short description. + +| File name | Description | +|------------------------|------------------------------------------------------------------------------------| +| `geometry.py` | Implements nodes (Point, Line, Circle, etc.) in the proof state graph. | +| `numericals.py` | Implements the numerical engine in the dynamic geometry environment. | +| `graph_utils.py` | Implements utilities for the proof state graph. | +| `graph.py` | Implements the proof state graph. | +| `problem.py` | Implements the classes that represent the problem premises, conclusion, DAG nodes. | +| `dd.py` | Implements DD and its traceback. | +| `ar.py` | Implements AR and its traceback. | +| `trace_back.py` | Implements the recursive traceback and dependency difference algorithm. | +| `ddar.py` | Implements the combination DD+AR. | +| `beam_search.py` | Implements beam decoding of a language model in JAX. | +| `models.py` | Implements the transformer model. | +| `transformer_layer.py` | Implements the transformer layer. | +| `decoder_stack.py` | Implements the transformer decoder stack. | +| `lm_inference.py` | Implements an interface to a trained LM to perform decoding. | +| `alphageometry.py` | Main script that loads problems, calls the DD+AR or AlphaGeometry solver, and prints solutions. | +| `pretty.py` | Pretty-formats the solutions output by the solvers. | +| `*_test.py` | Tests for the corresponding module. | +| `download.sh` | Script to download model checkpoints and the LM vocabulary. | +| `run.sh` | Script to execute instructions in README. | +| `run_tests.sh` | Script to execute the test suite. 
| + + +Resource files: + +| Resource file name | Description | +|------------------------|------------------------------------------------------------------------------------| +| `defs.txt` | Definitions of different geometric construction actions. | +| `rules.txt` | Deduction rules for DD. | +| `geometry_150M_generate.gin` | Gin config of the LM implemented in meliad. | +| `imo_ag_30.txt` | Problems in IMO-AG-30. | +| `jgex_ag_231.txt` | Problems in JGEX-AG-231. | -## Usage -Write example usage of your code. The instructions should ideally be valid when -copy-pasted, and will be used by your technical reviewer to verify that your -package functions correctly. ## Citing this work -Add citation details here, usually a pastable BibTeX snippet. +```bibtex +@Article{AlphaGeometryTrinh2023, + author = {Trinh, Trieu and Wu, Yuhuai and Le, Quoc and He, He and Luong, Thang}, + journal = {Nature}, + title = {Solving Olympiad Geometry without Human Demonstrations}, + year = {2024}, + doi = {10.1038/s41586-023-06747-5} +} +``` + +## Acknowledgements + +This research is a collaboration between the Google Brain team +(now Google DeepMind) and +the Computer Science Department of New York University. +We thank Rif A. Saurous, Denny Zhou, Christian Szegedy, Delesley Hutchins, +Thomas Kipf, Hieu Pham, Petar Veličković, Debidatta Dwibedi, +Kyunghyun Cho, Lerrel Pinto, Alfredo Canziani, +Thomas Wies, He He’s research group, +Evan Chen (the USA’s IMO team coach), +Mirek Olsak, Patrik Pak, +and all three of Nature's referees for their help and support. 
+ +The code of AlphaGeometry communicates with and/or references the following +separate libraries and packages: + +* [Abseil](https://github.com/abseil/abseil-py) +* [JAX](https://github.com/google/jax/) +* [matplotlib](https://matplotlib.org/) +* [NumPy](https://numpy.org) +* [SciPy](https://scipy.org) +* [TensorFlow](https://github.com/tensorflow/tensorflow) +* [Meliad](https://github.com/google-research/meliad) +* [Flax](https://github.com/google/flax) +* [Gin](https://github.com/google/gin-config) +* [T5](https://github.com/google-research/text-to-text-transfer-transformer) +* [SentencePiece](https://github.com/google/sentencepiece) -## License and disclaimer + + +We thank all their contributors and maintainers! + + +## Disclaimer + +This is not an officially supported Google product. + +This research code is provided "as-is" to the broader research community. +Google does not promise to maintain or otherwise support this code in any way. + +## Code License Copyright 2023 DeepMind Technologies Limited @@ -40,4 +435,11 @@ distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the licenses for the specific language governing permissions and limitations under those licenses. -This is not an official Google product. +## Model Parameters License + +The AlphaGeometry checkpoints and vocabulary are made available +under the terms of the Creative Commons Attribution 4.0 +International (CC BY 4.0) license. +You can find details at: +https://creativecommons.org/licenses/by/4.0/legalcode + diff --git a/alphageometry.py b/alphageometry.py new file mode 100644 index 0000000..06e4904 --- /dev/null +++ b/alphageometry.py @@ -0,0 +1,645 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Run DD+AR or AlphaGeometry solver. + +Please refer to README.md for detailed instructions. +""" + +import traceback + +from absl import app +from absl import flags +from absl import logging +import ddar +import graph as gh +import lm_inference as lm +import pretty as pt +import problem as pr + + +_GIN_SEARCH_PATHS = flags.DEFINE_list( + 'gin_search_paths', + ['third_party/py/meliad/transformer/configs'], + 'List of paths where the Gin config files are located.', +) +_GIN_FILE = flags.DEFINE_multi_string( + 'gin_file', ['base_htrans.gin'], 'List of Gin config files.' +) +_GIN_PARAM = flags.DEFINE_multi_string( + 'gin_param', None, 'Newline separated list of Gin parameter bindings.' +) + +_PROBLEMS_FILE = flags.DEFINE_string( + 'problems_file', + 'imo_ag_30.txt', + 'text file containing the problem strings. See imo_ag_30.txt for an example.', +) +_PROBLEM_NAME = flags.DEFINE_string( + 'problem_name', + 'imo_2000_p1', + 'name of the problem to solve; must be in the problems_file.', +) +_MODE = flags.DEFINE_string( + 'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`') +_DEFS_FILE = flags.DEFINE_string( + 'defs_file', + 'defs.txt', + 'definitions of available constructions to state a problem.', +) +_RULES_FILE = flags.DEFINE_string( + 'rules_file', 'rules.txt', 'list of deduction rules used by DD.' +) +_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.') +_VOCAB_PATH = flags.DEFINE_string( + 'vocab_path', '', 'path to the LM vocab file.' 
+) +_OUT_FILE = flags.DEFINE_string( + 'out_file', '', 'path to the solution output file.' +) # pylint: disable=line-too-long +_BEAM_SIZE = flags.DEFINE_integer( + 'beam_size', 1, 'beam size of the proof search.' +) # pylint: disable=line-too-long +_SEARCH_DEPTH = flags.DEFINE_integer( + 'search_depth', 1, 'search depth of the proof search.' +) # pylint: disable=line-too-long + +DEFINITIONS = None # contains definitions of construction actions +RULES = None # contains rules of deductions + + +def natural_language_statement(logical_statement: pr.Dependency) -> str: + """Convert logical_statement to natural language. + + Args: + logical_statement: pr.Dependency with .name and .args + + Returns: + a string of (pseudo) natural language of the predicate for human reader. + """ + names = [a.name.upper() for a in logical_statement.args] + names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names] + return pt.pretty_nl(logical_statement.name, names) + + +def proof_step_string( + proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool +) -> str: + """Translate proof to natural language. + + Args: + proof_step: pr.Dependency with .name and .args + refs: dict(hash: int) to keep track of derived predicates + last_step: boolean to keep track whether this is the last step. + + Returns: + a string of (pseudo) natural language of the proof step for human reader. + """ + premises, [conclusion] = proof_step + + premises_nl = ' & '.join( + [ + natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) + for p in premises + ] + ) + + if not premises: + premises_nl = 'similarly' + + refs[conclusion.hashed()] = len(refs) + + conclusion_nl = natural_language_statement(conclusion) + if not last_step: + conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()]) + + return f'{premises_nl} \u21d2 {conclusion_nl}' + + +def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None: + """Output the solution to out_file. 
+ + Args: + g: gh.Graph object, containing the proof state. + p: pr.Problem object, containing the theorem. + out_file: file to write to, empty string to skip writing to file. + """ + setup, aux, proof_steps, refs = ddar.get_proof_steps( + g, p.goal, merge_trivials=False + ) + + solution = '\n==========================' + solution += '\n * From theorem premises:\n' + premises_nl = [] + for premises, [points] in setup: + solution += ' '.join([p.name.upper() for p in points]) + ' ' + if not premises: + continue + premises_nl += [ + natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) + for p in premises + ] + solution += ': Points\n' + '\n'.join(premises_nl) + + solution += '\n\n * Auxiliary Constructions:\n' + aux_premises_nl = [] + for premises, [points] in aux: + solution += ' '.join([p.name.upper() for p in points]) + ' ' + aux_premises_nl += [ + natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()]) + for p in premises + ] + solution += ': Points\n' + '\n'.join(aux_premises_nl) + + # some special case where the deduction rule has a well known name. + r2name = { + 'r32': '(SSS)', + 'r33': '(SAS)', + 'r34': '(Similar Triangles)', + 'r35': '(Similar Triangles)', + 'r36': '(ASA)', + 'r37': '(ASA)', + 'r38': '(Similar Triangles)', + 'r39': '(Similar Triangles)', + 'r40': '(Congruent Triangles)', + 'a00': '(Distance chase)', + 'a01': '(Ratio chase)', + 'a02': '(Angle chase)', + } + + solution += '\n\n * Proof steps:\n' + for i, step in enumerate(proof_steps): + _, [con] = step + nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1) + rule_name = r2name.get(con.rule_name, '') + nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ') + solution += '{:03}. 
'.format(i + 1) + nl + '\n' + + solution += '==========================\n' + logging.info(solution) + if out_file: + with open(out_file, 'w') as f: + f.write(solution) + logging.info('Solution written to %s.', out_file) + + +def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference: + lm.parse_gin_configuration( + _GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value + ) + + return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search') + + +def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool: + """Run DD+AR. + + Args: + g: gh.Graph object, containing the proof state. + p: pr.Problem object, containing the problem statement. + out_file: path to output file if solution is found. + + Returns: + Boolean, whether DD+AR finishes successfully. + """ + ddar.solve(g, RULES, p, max_level=1000) + + goal_args = g.names2nodes(p.goal.args) + if not g.check(p.goal.name, goal_args): + logging.info('DD+AR failed to solve the problem.') + return False + + write_solution(g, p, out_file) + return True + + +def translate_constrained_to_constructive( + point: str, name: str, args: list[str] +) -> tuple[str, list[str]]: + """Translate a predicate from constraint-based to construction-based. + + Args: + point: str: name of the new point + name: str: name of the predicate, e.g., perp, para, etc. + args: list[str]: list of predicate args. + + Returns: + (name, args): translated to constructive predicate. 
+ """ + if name in ['T', 'perp']: + a, b, c, d = args + if point in [c, d]: + a, b, c, d = c, d, a, b + if point == b: + a, b = b, a + if point == d: + c, d = d, c + if a == c and a == point: + return 'on_dia', [a, b, d] + return 'on_tline', [a, b, c, d] + + elif name in ['P', 'para']: + a, b, c, d = args + if point in [c, d]: + a, b, c, d = c, d, a, b + if point == b: + a, b = b, a + return 'on_pline', [a, b, c, d] + + elif name in ['D', 'cong']: + a, b, c, d = args + if point in [c, d]: + a, b, c, d = c, d, a, b + if point == b: + a, b = b, a + if point == d: + c, d = d, c + if a == c and a == point: + return 'on_bline', [a, b, d] + if b in [c, d]: + if b == d: + c, d = d, c # pylint: disable=unused-variable + return 'on_circle', [a, b, d] + return 'eqdistance', [a, b, c, d] + + elif name in ['C', 'coll']: + a, b, c = args + if point == b: + a, b = b, a + if point == c: + a, b, c = c, a, b + return 'on_line', [a, b, c] + + elif name in ['^', 'eqangle']: + a, b, c, d, e, f = args + + if point in [d, e, f]: + a, b, c, d, e, f = d, e, f, a, b, c + + x, b, y, c, d = b, c, e, d, f + if point == b: + a, b, c, d = b, a, d, c + + if point == d and x == y: # x p x b = x c x p + return 'angle_bisector', [point, b, x, c] + + if point == x: + return 'eqangle3', [x, a, b, y, c, d] + + return 'on_aline', [a, x, b, c, y, d] + + elif name in ['cyclic', 'O']: + a, b, c = [x for x in args if x != point] + return 'on_circum', [point, a, b, c] + + return name, args + + +def check_valid_args(name: str, args: list[str]) -> bool: + """Check whether a predicate is grammarically correct. + + Args: + name: str: name of the predicate + args: list[str]: args of the predicate + + Returns: + bool: whether the predicate arg count is valid. 
+ """ + if name == 'perp': + if len(args) != 4: + return False + a, b, c, d = args + if len({a, b}) < 2: + return False + if len({c, d}) < 2: + return False + elif name == 'para': + if len(args) != 4: + return False + a, b, c, d = args + if len({a, b, c, d}) < 4: + return False + elif name == 'cong': + if len(args) != 4: + return False + a, b, c, d = args + if len({a, b}) < 2: + return False + if len({c, d}) < 2: + return False + elif name == 'coll': + if len(args) != 3: + return False + a, b, c = args + if len({a, b, c}) < 3: + return False + elif name == 'cyclic': + if len(args) != 4: + return False + a, b, c, d = args + if len({a, b, c, d}) < 4: + return False + elif name == 'eqangle': + if len(args) != 8: + return False + a, b, c, d, e, f, g, h = args + if len({a, b, c, d}) < 3: + return False + if len({e, f, g, h}) < 3: + return False + return True + + +def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str: + """Whether a string of aux construction can be constructed. + + Args: + string: str: the string describing aux construction. + g: gh.Graph: the current proof state. + + Returns: + str: whether this construction is valid. If not, starts with "ERROR:". + """ + if string[-1] != ';': + return 'ERROR: must end with ;' + + head, prem_str = string.split(' : ') + point = head.strip() + + if len(point) != 1 or point == ' ': + return f'ERROR: invalid point name {point}' + + existing_points = [p.name for p in g.all_points()] + if point in existing_points: + return f'ERROR: point {point} already exists.' + + prem_toks = prem_str.split()[:-1] # remove the EOS ' ;' + prems = [[]] + + for i, tok in enumerate(prem_toks): + if tok.isdigit(): + if i < len(prem_toks) - 1: + prems.append([]) + else: + prems[-1].append(tok) + + if len(prems) > 2: + return 'ERROR: there cannot be more than two predicates.' 
+ + clause_txt = point + ' = ' + constructions = [] + + for prem in prems: + name, *args = prem + + if point not in args: + return f'ERROR: {point} not found in predicate args.' + + if not check_valid_args(pt.map_symbol(name), args): + return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args) + + for a in args: + if a != point and a not in existing_points: + return f'ERROR: point {a} does not exist.' + + try: + name, args = translate_constrained_to_constructive(point, name, args) + except: # pylint: disable=bare-except + return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args) + + if name == 'on_aline': + if args.count(point) > 1: + return f'ERROR: on_aline involves twice {point}' + + constructions += [name + ' ' + ' '.join(args)] + + clause_txt += ', '.join(constructions) + clause = pr.Clause.from_txt(clause_txt) + + try: + g.copy().add_clause(clause, 0, DEFINITIONS) + except: # pylint: disable=bare-except + return 'ERROR: ' + traceback.format_exc() + + return clause_txt + + +def insert_aux_to_premise(pstring: str, auxstring: str) -> str: + """Insert auxiliary constructs from proof to premise. + + Args: + pstring: str: describing the problem to solve. + auxstring: str: describing the auxiliary construction. + + Returns: + str: new pstring with auxstring inserted before the conclusion. + """ + setup, goal = pstring.split(' ? ') + return setup + '; ' + auxstring + ' ? ' + goal + + +class BeamQueue: + """Keep only the top k objects according to their values.""" + + def __init__(self, max_size: int = 512): + self.queue = [] + self.max_size = max_size + + def add(self, node: object, val: float) -> None: + """Add a new node to this queue.""" + + if len(self.queue) < self.max_size: + self.queue.append((val, node)) + return + + # Find the minimum node: + min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1]) + + # Replace it if the new node has a higher value. 
+ if val > min_val: + self.queue[min_idx] = (val, node) + + def __iter__(self): + for val, node in self.queue: + yield val, node + + def __len__(self) -> int: + return len(self.queue) + + +def run_alphageometry( + model: lm.LanguageModelInference, + p: pr.Problem, + search_depth: int, + beam_size: int, + out_file: str, +) -> bool: + """Simplified code to run AlphaGeometry proof search. + + We removed all optimizations that are infrastructure-dependent, e.g. + parallelized model inference on multi GPUs, + parallelized DD+AR on multiple CPUs, + parallel execution of LM and DD+AR, + shared pool of CPU workers across different problems, etc. + + Many other speed optimizations and abstractions are also removed to + better present the core structure of the proof search. + + Args: + model: Interface with inference-related endpoints to JAX's model. + p: pr.Problem object describing the problem to solve. + search_depth: max proof search depth. + beam_size: beam size of the proof search. + out_file: path to output file if solution is found. + + Returns: + boolean of whether this is solved. + """ + # translate the problem to a string of grammar that the LM is trained on. + string = p.setup_str_from_problem(DEFINITIONS) + # special tokens prompting the LM to generate auxiliary points. + string += ' {F1} x00' + # the graph to represent the proof state. + g, _ = gh.Graph.build_problem(p, DEFINITIONS) + + # First we run the symbolic engine DD+AR: + if run_ddar(g, p, out_file): + return True + + # beam search for the proof + # each node in the search tree is a 3-tuple: + # (, + # , + # ) + beam_queue = BeamQueue(max_size=beam_size) + # originally the beam search tree starts with a single node (a 3-tuple): + beam_queue.add( + node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0. + ) + + for depth in range(search_depth): + logging.info( + 'Depth %s. 
There are %i nodes to expand:', depth, len(beam_queue) + ) + for _, (_, string, _) in beam_queue: + logging.info(string) + + new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue. + + for prev_score, (g, string, pstring) in beam_queue: + logging.info('Decoding from %s', string) + outputs = model.beam_decode(string, eos_tokens=[';']) + + # Translate the LM outputs to the constructive language + # so that we can update the graph representing proof states: + translations = [ + try_translate_constrained_to_construct(o, g) + for o in outputs['seqs_str'] + ] + + # Couple the LM outputs with their translations. + candidates = zip(outputs['seqs_str'], translations, outputs['scores']) + + # Bring the highest-scoring candidate first. + candidates = reversed(list(candidates)) + + for lm_out, translation, score in candidates: + logging.info('LM output (score=%f): "%s"', score, lm_out) + logging.info('Translation: "%s"\n', translation) + + if translation.startswith('ERROR:'): + # The construction is invalid. + continue + + # Update the constructive statement of the problem with the aux point: + candidate_pstring = insert_aux_to_premise(pstring, translation) + + logging.info('Solving: "%s"', candidate_pstring) + p_new = pr.Problem.from_txt(candidate_pstring) + + # This is the new proof state graph representation: + g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS) + if run_ddar(g_new, p_new, out_file): + logging.info('Solved.') + return True + + # Add the candidate to the beam queue. + new_queue.add( + # The string for the new node is old_string + lm output + + # the special token asking for a new auxiliary point ' x00': + node=(g_new, string + ' ' + lm_out + ' x00', candidate_pstring), + # The score of each node is the sum of scores of all nodes + # on the path to itself. For beam search, there is no need to + # normalize by path length because all nodes in the beam + # have the same path length. 
+ val=prev_score + score, + ) + # Note that the queue only maintain at most beam_size nodes + # so this new node might possibly be dropped depending on its value. + + # replace the old queue with new queue before the new proof search depth. + beam_queue = new_queue + + return False + + +def main(_): + global DEFINITIONS + global RULES + + # definitions of terms used in our domain-specific language. + DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True) + # load inference rules used in DD. + RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True) + + # when using the language model, + # point names will be renamed to alphabetical a, b, c, d, e, ... + # instead of staying with their original names, + # in order to match the synthetic training data generation. + need_rename = _MODE.value != 'ddar' + + # load problems from the problems_file, + problems = pr.Problem.from_txt_file( + _PROBLEMS_FILE.value, to_dict=True, translate=need_rename + ) + + if _PROBLEM_NAME.value not in problems: + raise ValueError( + f'Problem name `{_PROBLEM_NAME.value}` ' + + f'not found in `{_PROBLEMS_FILE.value}`' + ) + + this_problem = problems[_PROBLEM_NAME.value] + + if _MODE.value == 'ddar': + g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS) + run_ddar(g, this_problem, _OUT_FILE.value) + + elif _MODE.value == 'alphageometry': + model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value) + run_alphageometry( + model, + this_problem, + _SEARCH_DEPTH.value, + _BEAM_SIZE.value, + _OUT_FILE.value, + ) + + else: + raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}') + + +if __name__ == '__main__': + app.run(main) diff --git a/alphageometry_test.py b/alphageometry_test.py new file mode 100644 index 0000000..dc1e48d --- /dev/null +++ b/alphageometry_test.py @@ -0,0 +1,103 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for alphageometry.py.""" + +import unittest + +from absl.testing import absltest +import alphageometry + + +class AlphaGeometryTest(unittest.TestCase): + + def test_translate_constrained_to_constructive(self): + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'T', list('addb') + ), + ('on_dia', ['d', 'b', 'a']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'T', list('adbc') + ), + ('on_tline', ['d', 'a', 'b', 'c']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'P', list('bcda') + ), + ('on_pline', ['d', 'a', 'b', 'c']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'D', list('bdcd') + ), + ('on_bline', ['d', 'c', 'b']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'D', list('bdcb') + ), + ('on_circle', ['d', 'b', 'c']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'D', list('bacd') + ), + ('eqdistance', ['d', 'c', 'b', 'a']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'C', list('bad') + ), + ('on_line', ['d', 'b', 'a']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'C', list('bad') + ), + ('on_line', ['d', 'b', 'a']), + ) + self.assertEqual( + alphageometry.translate_constrained_to_constructive( + 'd', 'O', list('abcd') + ), + ('on_circum', ['d', 
'a', 'b', 'c']), + ) + + def test_insert_aux_to_premise(self): + pstring = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long + auxstring = 'e = on_line e a c, on_line e b d' + + target = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long + self.assertEqual( + alphageometry.insert_aux_to_premise(pstring, auxstring), target + ) + + def test_beam_queue(self): + beam_queue = alphageometry.BeamQueue(max_size=2) + + beam_queue.add('a', 1) + beam_queue.add('b', 2) + beam_queue.add('c', 3) + + beam_queue = list(beam_queue) + self.assertEqual(beam_queue, [(3, 'c'), (2, 'b')]) + + +if __name__ == '__main__': + absltest.main() diff --git a/ar.py b/ar.py new file mode 100644 index 0000000..01dd8cc --- /dev/null +++ b/ar.py @@ -0,0 +1,752 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implementing Algebraic Reasoning (AR).""" + +from collections import defaultdict # pylint: disable=g-importing-member +from fractions import Fraction as frac # pylint: disable=g-importing-member +from typing import Any, Generator + +import geometry as gm +import numpy as np +import problem as pr +from scipy import optimize + + +class InfQuotientError(Exception): + pass + + +def _gcd(x: int, y: int) -> int: + while y: + x, y = y, x % y + return x + + +def simplify(n: int, d: int) -> tuple[int, int]: + g = _gcd(n, d) + return (n // g, d // g) + + +# maximum denominator for a fraction. +MAX_DENOMINATOR = 1000000 + +# tolerance for fraction approximation +TOL = 1e-15 + + +def get_quotient(v: float) -> tuple[int, int]: + n = v + d = 1 + while abs(n - round(n)) > TOL: + d += 1 + n += v + if d > MAX_DENOMINATOR: + e = InfQuotientError(v) + raise e + + n = int(round(n)) + return simplify(n, d) + + +def fix_v(v: float) -> float: + n, d = get_quotient(v) + return n / d + + +def fix(e: dict[str, float]) -> dict[str, float]: + return {k: fix_v(v) for k, v in e.items()} + + +def frac_string(f: frac) -> str: + n, d = get_quotient(f) + return f'{n}/{d}' + + +def hashed(e: dict[str, float]) -> tuple[tuple[str, float], ...]: + return tuple(sorted(list(e.items()))) + + +def is_zero(e: dict[str, float]) -> bool: + return len(strip(e)) == 0 # pylint: disable=g-explicit-length-test + + +def strip(e: dict[str, float]) -> dict[str, float]: + return {v: c for v, c in e.items() if c != 0} + + +def plus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]: + e = dict(e1) + for v, c in e2.items(): + if v in e: + e[v] += c + else: + e[v] = c + return strip(e) + + +def plus_all(*es: list[dict[str, float]]) -> dict[str, float]: + result = {} + for e in es: + result = plus(result, e) + return result + + +def mult(e: dict[str, float], m: float) -> dict[str, float]: + return {v: m * c for v, c in 
e.items()} + + +def minus(e1: dict[str, float], e2: dict[str, float]) -> dict[str, float]: + return plus(e1, mult(e2, -1)) + + +def div(e1: dict[str, float], e2: dict[str, float]) -> float: + """Divide e1 by e2.""" + e1 = strip(e1) + e2 = strip(e2) + if set(e1.keys()) != set(e2.keys()): + return None + + n, d = None, None + + for v, c1 in e1.items(): + c2 = e2[v] # we want c1/c2 = n/d => c1*d=c2*n + if n is not None and c1 * d != c2 * n: + return None + n, d = c1, c2 + return frac(n) / frac(d) + + +def recon(e: dict[str, float], const: str) -> tuple[str, dict[str, float]]: + """Reconcile one variable in the expression e=0, given const.""" + e = strip(e) + if len(e) == 0: # pylint: disable=g-explicit-length-test + return None + + v0 = None + for v in e: + if v != const: + v0 = v + break + if v0 is None: + return v0 + + c0 = e.pop(v0) + return v0, {v: -c / c0 for v, c in e.items()} + + +def replace( + e: dict[str, float], v0: str, e0: dict[str, float] +) -> dict[str, float]: + if v0 not in e: + return e + e = dict(e) + m = e.pop(v0) + return plus(e, mult(e0, m)) + + +def comb2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: + if len(elems) < 1: + return + for i, e1 in enumerate(elems[:-1]): + for e2 in elems[i + 1 :]: + yield e1, e2 + + +def perm2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: + for e1, e2 in comb2(elems): + yield e1, e2 + yield e2, e1 + + +def chain2(elems: list[Any]) -> Generator[tuple[Any, Any], None, None]: + if len(elems) < 2: + return + for i, e1 in enumerate(elems[:-1]): + yield e1, elems[i + 1] + + +def update_groups( + groups1: list[Any], groups2: list[Any] +) -> tuple[list[Any], list[tuple[Any, Any]], list[list[Any]]]: + """Update groups of equivalent elements. + + Given groups1 = [set1, set2, set3, ..] + where all elems within each set_i is defined to be "equivalent" to each other. + (but not across the sets) + + Incoming groups2 = [set1, set2, ...] 
similar to groups1 - it is the
+  additional equivalent information on elements in groups1.
+
+  Return the new updated groups1 and the set of links
+  that make it that way.
+
+  Example:
+    groups1 = [{1, 2}, {3, 4, 5}, {6, 7}]
+    groups2 = [{2, 3, 8}, {9, 10, 11}]
+
+    => new groups1 and links:
+    groups1 = [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}]
+    links = (2, 3), (3, 8), (9, 10), (10, 11)
+
+  Explain: since groups2 says 2 and 3 are equivalent (with {2, 3, 8}),
+  {1, 2} and {3, 4, 5} in groups1 will be merged,
+  because 2 and 3 each belong to those 2 groups.
+  Additionally, 8 also belongs to this same merged group.
+  {6, 7} is left alone, while {9, 10, 11} is a completely new set.
+
+  The links that make this all happen are:
+  (2, 3): to merge {1, 2} and {3, 4, 5}
+  (3, 8): to link 8 into the merged {1, 2, 3, 4, 5}
+  (9, 10) and (10, 11): to make the new group {9, 10, 11}
+
+  Args:
+    groups1: a list of sets.
+    groups2: a list of sets.
+
+  Returns:
+    groups1, links, history: result of the update.
+  """
+  history = []
+  links = []
+  for g2 in groups2:
+    joins = [None] * len(groups1)  # mark which sets in groups1 are merged
+    merged_g1 = set()  # merge them into this.
+    old = None  # any elem in g2 that belongs to any set in groups1 (old)
+    new = []  # all elems in g2 that are new
+
+    for e in g2:
+      found = False
+      for i, g1 in enumerate(groups1):
+        if e not in g1:
+          continue
+
+        found = True
+        if joins[i]:
+          continue
+
+        joins[i] = True
+        merged_g1.update(g1)
+
+        if old is not None:
+          links.append((old, e))  # link to make merging happen.
+        old = e
+
+      if not found:  # e is new!
+        new.append(e)
+
+    # now chain elems in new together.
+ if old is not None and new: + links.append((old, new[0])) + merged_g1.update(new) + + links += chain2(new) + + new_groups1 = [] + if merged_g1: # put the merged_g1 in first + new_groups1.append(merged_g1) + + # put the remaining (unjoined) groups in + new_groups1 += [g1 for j, g1 in zip(joins, groups1) if not j] + + if old is None and new: + new_groups1 += [set(new)] + + groups1 = new_groups1 + history.append(groups1) + + return groups1, links, history + + +class Table: + """The coefficient matrix.""" + + def __init__(self, const: str = '1'): + self.const = const + self.v2e = {} + self.add_free(const) # the table {var: expression} + + # to cache what is already derived/inputted + self.eqs = set() + self.groups = [] # groups of equal pairs. + + # for why (linprog) + self.c = [] + self.v2i = {} # v -> index of row in A. + self.deps = [] # equal number of columns. + self.A = np.zeros([0, 0]) # pylint: disable=invalid-name + self.do_why = True + + def add_free(self, v: str) -> None: + self.v2e[v] = {v: frac(1)} + + def replace(self, v0: str, e0: dict[str, float]) -> None: + for v, e in list(self.v2e.items()): + self.v2e[v] = replace(e, v0, e0) + + def add_expr(self, vc: list[tuple[str, float]]) -> bool: + """Add a new equality, represented by the list of tuples vc=[(v, c), ..].""" + result = {} + free = [] + + for v, c in vc: + c = frac(c) + if v in self.v2e: + result = plus(result, mult(self.v2e[v], c)) + else: + free += [(v, c)] + + if free == []: # pylint: disable=g-explicit-bool-comparison + if is_zero(self.modulo(result)): + return False + result = recon(result, self.const) + if result is None: + return False + v, e = result + self.replace(v, e) + + elif len(free) == 1: + v, m = free[0] + self.v2e[v] = mult(result, frac(-1, m)) + + else: + dependent_v = None + for v, m in free: + if dependent_v is None and v != self.const: + dependent_v = (v, m) + continue + + self.add_free(v) + result = plus(result, {v: m}) + + v, m = dependent_v + self.v2e[v] = mult(result, 
frac(-1, m)) + + return True + + def register(self, vc: list[tuple[str, float]], dep: pr.Dependency) -> None: + """Register a new equality vc=[(v, c), ..] with traceback dependency dep.""" + result = plus_all(*[{v: c} for v, c in vc]) + if is_zero(result): + return + + vs, _ = zip(*vc) + for v in vs: + if v not in self.v2i: + self.v2i[v] = len(self.v2i) + + (m, n), l = self.A.shape, len(self.v2i) + if l > m: + self.A = np.concatenate([self.A, np.zeros([l - m, n])], 0) + + new_column = np.zeros([len(self.v2i), 2]) # N, 2 + for v, c in vc: + new_column[self.v2i[v], 0] += float(c) + new_column[self.v2i[v], 1] -= float(c) + + self.A = np.concatenate([self.A, new_column], 1) + self.c += [1.0, -1.0] + self.deps += [dep] + + def register2( + self, a: str, b: str, m: float, n: float, dep: pr.Dependency + ) -> None: + self.register([(a, m), (b, -n)], dep) + + def register3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None: + self.register([(a, 1), (b, -1), (self.const, -f)], dep) + + def register4( + self, a: str, b: str, c: str, d: str, dep: pr.Dependency + ) -> None: + self.register([(a, 1), (b, -1), (c, -1), (d, 1)], dep) + + def why(self, e: dict[str, float]) -> list[Any]: + """AR traceback == MILP.""" + if not self.do_why: + return [] + # why expr == 0? + # Solve min(c^Tx) s.t. 
A_eq * x = b_eq, x >= 0 + e = strip(e) + if not e: + return [] + + b_eq = [0] * len(self.v2i) + for v, c in e.items(): + b_eq[self.v2i[v]] += float(c) + + try: + x = optimize.linprog(c=self.c, A_eq=self.A, b_eq=b_eq, method='highs')[ + 'x' + ] + except: # pylint: disable=bare-except + x = optimize.linprog( + c=self.c, + A_eq=self.A, + b_eq=b_eq, + )['x'] + + deps = [] + for i, dep in enumerate(self.deps): + if x[2 * i] > 1e-12 or x[2 * i + 1] > 1e-12: + if dep not in deps: + deps.append(dep) + return deps + + def record_eq(self, v1: str, v2: str, v3: str, v4: str) -> None: + self.eqs.add((v1, v2, v3, v4)) + self.eqs.add((v2, v1, v4, v3)) + self.eqs.add((v3, v4, v1, v2)) + self.eqs.add((v4, v3, v2, v1)) + + def check_record_eq(self, v1: str, v2: str, v3: str, v4: str) -> bool: + if (v1, v2, v3, v4) in self.eqs: + return True + if (v2, v1, v4, v3) in self.eqs: + return True + if (v3, v4, v1, v2) in self.eqs: + return True + if (v4, v3, v2, v1) in self.eqs: + return True + return False + + def add_eq2( + self, a: str, b: str, m: float, n: float, dep: pr.Dependency + ) -> None: + # a/b = m/n + if not self.add_expr([(a, m), (b, -n)]): + return [] + self.register2(a, b, m, n, dep) + + def add_eq3(self, a: str, b: str, f: float, dep: pr.Dependency) -> None: + # a - b = f * constant + self.eqs.add((a, b, frac(f))) + self.eqs.add((b, a, frac(1 - f))) + + if not self.add_expr([(a, 1), (b, -1), (self.const, -f)]): + return [] + + self.register3(a, b, f, dep) + + def add_eq4(self, a: str, b: str, c: str, d: str, dep: pr.Dependency) -> None: + # a - b = c - d + self.record_eq(a, b, c, d) + self.record_eq(a, c, b, d) + + expr = list(minus({a: 1, b: -1}, {c: 1, d: -1}).items()) + + if not self.add_expr(expr): + return [] + + self.register4(a, b, c, d, dep) + self.groups, _, _ = update_groups( + self.groups, [{(a, b), (c, d)}, {(b, a), (d, c)}] + ) + + def pairs(self) -> Generator[list[tuple[str, str]], None, None]: + for v1, v2 in perm2(list(self.v2e.keys())): # pylint: 
disable=g-builtin-op + if v1 == self.const or v2 == self.const: + continue + yield v1, v2 + + def modulo(self, e: dict[str, float]) -> dict[str, float]: + return strip(e) + + def get_all_eqs( + self, + ) -> dict[tuple[tuple[str, float], ...], list[tuple[str, str]]]: + h2pairs = defaultdict(list) + for v1, v2 in self.pairs(): + e1, e2 = self.v2e[v1], self.v2e[v2] + e12 = minus(e1, e2) + h12 = hashed(self.modulo(e12)) + h2pairs[h12].append((v1, v2)) + return h2pairs + + def get_all_eqs_and_why( + self, return_quads: bool = True + ) -> Generator[Any, None, None]: + """Check all 4/3/2-permutations for new equalities.""" + groups = [] + + for h, vv in self.get_all_eqs().items(): + if h == (): # pylint: disable=g-explicit-bool-comparison + for v1, v2 in vv: + if (v1, v2) in self.eqs or (v2, v1) in self.eqs: + continue + self.eqs.add((v1, v2)) + # why v1 - v2 = e12 ? (note modulo(e12) == 0) + why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2])) + yield v1, v2, self.why(why_dict) + continue + + if len(h) == 1 and h[0][0] == self.const: + for v1, v2 in vv: + frac = h[0][1] # pylint: disable=redefined-outer-name + if (v1, v2, frac) in self.eqs: + continue + self.eqs.add((v1, v2, frac)) + # why v1 - v2 = e12 ? (note modulo(e12) == 0) + why_dict = minus({v1: 1, v2: -1}, minus(self.v2e[v1], self.v2e[v2])) + value = simplify(frac.numerator, frac.denominator) + yield v1, v2, value, self.why(why_dict) + continue + + groups.append(vv) + + if not return_quads: + return + + self.groups, links, _ = update_groups(self.groups, groups) + for (v1, v2), (v3, v4) in links: + if self.check_record_eq(v1, v2, v3, v4): + continue + e12 = minus(self.v2e[v1], self.v2e[v2]) + e34 = minus(self.v2e[v3], self.v2e[v4]) + + why_dict = minus( # why (v1-v2)-(v3-v4)=e12-e34? 
+ minus({v1: 1, v2: -1}, {v3: 1, v4: -1}), minus(e12, e34) + ) + self.record_eq(v1, v2, v3, v4) + yield v1, v2, v3, v4, self.why(why_dict) + + +class GeometricTable(Table): + """Abstract class representing the coefficient matrix (table) A.""" + + def __init__(self, name: str = ''): + super().__init__(name) + self.v2obj = {} + + def get_name(self, objs: list[Any]) -> list[str]: + self.v2obj.update({o.name: o for o in objs}) + return [o.name for o in objs] + + def map2obj(self, names: list[str]) -> list[Any]: + return [self.v2obj[n] for n in names] + + def get_all_eqs_and_why( + self, return_quads: bool + ) -> Generator[Any, None, None]: + for out in super().get_all_eqs_and_why(return_quads): + if len(out) == 3: + x, y, why = out + x, y = self.map2obj([x, y]) + yield x, y, why + if len(out) == 4: + x, y, f, why = out + x, y = self.map2obj([x, y]) + yield x, y, f, why + if len(out) == 5: + a, b, x, y, why = out + a, b, x, y = self.map2obj([a, b, x, y]) + yield a, b, x, y, why + + +class RatioTable(GeometricTable): + """Coefficient matrix A for log(distance).""" + + def __init__(self, name: str = ''): + name = name or '1' + super().__init__(name) + self.one = self.const + + def add_eq(self, l1: gm.Length, l2: gm.Length, dep: pr.Dependency) -> None: + l1, l2 = self.get_name([l1, l2]) + return super().add_eq3(l1, l2, 0.0, dep) + + def add_const_ratio( + self, l1: gm.Length, l2: gm.Length, m: float, n: float, dep: pr.Dependency + ) -> None: + l1, l2 = self.get_name([l1, l2]) + return super().add_eq2(l1, l2, m, n, dep) + + def add_eqratio( + self, + l1: gm.Length, + l2: gm.Length, + l3: gm.Length, + l4: gm.Length, + dep: pr.Dependency, + ) -> None: + l1, l2, l3, l4 = self.get_name([l1, l2, l3, l4]) + return self.add_eq4(l1, l2, l3, l4, dep) + + def get_all_eqs_and_why(self) -> Generator[Any, None, None]: + return super().get_all_eqs_and_why(True) + + +class AngleTable(GeometricTable): + """Coefficient matrix A for slope(direction).""" + + def __init__(self, name: str = 
''): + name = name or 'pi' + super().__init__(name) + self.pi = self.const + + def modulo(self, e: dict[str, float]) -> dict[str, float]: + e = strip(e) + if self.pi not in e: + return super().modulo(e) + + e[self.pi] = e[self.pi] % 1 + return strip(e) + + def add_para( + self, d1: gm.Direction, d2: gm.Direction, dep: pr.Dependency + ) -> None: + return self.add_const_angle(d1, d2, 0, dep) + + def add_const_angle( + self, d1: gm.Direction, d2: gm.Direction, ang: float, dep: pr.Dependency + ) -> None: + if ang and d2._obj.num > d1._obj.num: # pylint: disable=protected-access + d1, d2 = d2, d1 + ang = 180 - ang + + d1, d2 = self.get_name([d1, d2]) + + num, den = simplify(ang, 180) + ang = frac(int(num), int(den)) + return super().add_eq3(d1, d2, ang, dep) + + def add_eqangle( + self, + d1: gm.Direction, + d2: gm.Direction, + d3: gm.Direction, + d4: gm.Direction, + dep: pr.Dependency, + ) -> None: + """Add the inequality d1-d2=d3-d4.""" + # Use string as variables. + l1, l2, l3, l4 = [d._obj.num for d in [d1, d2, d3, d4]] # pylint: disable=protected-access + d1, d2, d3, d4 = self.get_name([d1, d2, d3, d4]) + ang1 = {d1: 1, d2: -1} + ang2 = {d3: 1, d4: -1} + + if l2 > l1: + ang1 = plus({self.pi: 1}, ang1) + if l4 > l3: + ang2 = plus({self.pi: 1}, ang2) + + ang12 = minus(ang1, ang2) + self.record_eq(d1, d2, d3, d4) + self.record_eq(d1, d3, d2, d4) + + expr = list(ang12.items()) + if not self.add_expr(expr): + return [] + + self.register(expr, dep) + + def get_all_eqs_and_why(self) -> Generator[Any, None, None]: + return super().get_all_eqs_and_why(True) + + +class DistanceTable(GeometricTable): + """Coefficient matrix A for position(point, line).""" + + def __init__(self, name: str = ''): + name = name or '1:1' + self.merged = {} + self.ratios = set() + super().__init__(name) + + def pairs(self) -> Generator[tuple[str, str], None, None]: + l2vs = defaultdict(list) + for v in list(self.v2e.keys()): # pylint: disable=g-builtin-op + if v == self.const: + continue + l, p = 
v.split(':') + l2vs[l].append(p) + + for l, ps in l2vs.items(): + for p1, p2 in perm2(ps): + yield l + ':' + p1, l + ':' + p2 + + def name(self, l: gm.Line, p: gm.Point) -> str: + v = l.name + ':' + p.name + self.v2obj[v] = (l, p) + return v + + def map2obj(self, names: list[str]) -> list[gm.Point]: + return [self.v2obj[n][1] for n in names] + + def add_cong( + self, + l12: gm.Line, + l34: gm.Line, + p1: gm.Point, + p2: gm.Point, + p3: gm.Point, + p4: gm.Point, + dep: pr.Dependency, + ) -> None: + """Add that distance between p1 and p2 (on l12) == p3 and p4 (on l34).""" + if p2.num > p1.num: + p1, p2 = p2, p1 + if p4.num > p3.num: + p3, p4 = p4, p3 + + p1 = self.name(l12, p1) + p2 = self.name(l12, p2) + p3 = self.name(l34, p3) + p4 = self.name(l34, p4) + return super().add_eq4(p1, p2, p3, p4, dep) + + def get_all_eqs_and_why(self) -> Generator[Any, None, None]: + for x in super().get_all_eqs_and_why(True): + yield x + + # Now we figure out all the const ratios. + h2pairs = defaultdict(list) + for v1, v2 in self.pairs(): + if (v1, v2) in self.merged: + continue + e1, e2 = self.v2e[v1], self.v2e[v2] + e12 = minus(e1, e2) + h12 = hashed(e12) + h2pairs[h12].append((v1, v2, e12)) + + for (_, vves1), (_, vves2) in perm2(list(h2pairs.items())): + v1, v2, e12 = vves1[0] + for v1_, v2_, _ in vves1[1:]: + self.merged[(v1_, v2_)] = (v1, v2) + + v3, v4, e34 = vves2[0] + for v3_, v4_, _ in vves2[1:]: + self.merged[(v3_, v4_)] = (v3, v4) + + if (v1, v2, v3, v4) in self.ratios: + continue + + d12 = div(e12, e34) + if d12 is None or d12 > 1 or d12 < 0: + continue + + self.ratios.add((v1, v2, v3, v4)) + self.ratios.add((v2, v1, v4, v3)) + + n, d = d12.numerator, d12.denominator + + # (v1 - v2) * d = (v3 - v4) * n + why_dict = minus( + minus({v1: d, v2: -d}, {v3: n, v4: -n}), + minus(mult(e12, d), mult(e34, n)), # there is no modulo, so this is 0 + ) + + v1, v2, v3, v4 = self.map2obj([v1, v2, v3, v4]) + yield v1, v2, v3, v4, abs(n), abs(d), self.why(why_dict) diff --git a/ar_test.py 
b/ar_test.py new file mode 100644 index 0000000..df68e90 --- /dev/null +++ b/ar_test.py @@ -0,0 +1,204 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for ar.py.""" +import unittest + +from absl.testing import absltest +import ar +import graph as gh +import problem as pr + + +class ARTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) + + def test_update_groups(self): + """Test for update_groups.""" + groups1 = [{1, 2}, {3, 4, 5}, {6, 7}] + groups2 = [{2, 3, 8}, {9, 10, 11}] + + _, links, history = ar.update_groups(groups1, groups2) + self.assertEqual( + history, + [ + [{1, 2, 3, 4, 5, 8}, {6, 7}], + [{1, 2, 3, 4, 5, 8}, {6, 7}, {9, 10, 11}], + ], + ) + self.assertEqual(links, [(2, 3), (3, 8), (9, 10), (10, 11)]) + + groups1 = [{1, 2}, {3, 4}, {5, 6}, {7, 8}] + groups2 = [{2, 3, 8, 9, 10}, {3, 6, 11}] + + _, links, history = ar.update_groups(groups1, groups2) + self.assertEqual( + history, + [ + [{1, 2, 3, 4, 7, 8, 9, 10}, {5, 6}], + [{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}], + ], + ) + self.assertEqual(links, [(2, 3), (3, 8), (8, 9), (9, 10), (3, 6), (6, 11)]) + + groups1 = [] + groups2 = [{1, 2}, {3, 4}, {5, 6}, {2, 3}] + + _, links, history = 
ar.update_groups(groups1, groups2) + self.assertEqual( + history, + [ + [{1, 2}], + [{1, 2}, {3, 4}], + [{1, 2}, {3, 4}, {5, 6}], + [{1, 2, 3, 4}, {5, 6}], + ], + ) + self.assertEqual(links, [(1, 2), (3, 4), (5, 6), (2, 3)]) + + def test_generic_table_simple(self): + tb = ar.Table() + + # If a-b = b-c & d-a = c-d + tb.add_eq4('a', 'b', 'b', 'c', 'fact1') + tb.add_eq4('d', 'a', 'c', 'd', 'fact2') + tb.add_eq4('x', 'y', 'z', 't', 'fact3') # distractor fact + + # Then b=d, because {fact1, fact2} but not fact3. + result = list(tb.get_all_eqs_and_why()) + self.assertIn(('b', 'd', ['fact1', 'fact2']), result) + + def test_angle_table_inbisector_exbisector(self): + """Test that AR can figure out bisector & ex-bisector are perpendicular.""" + # Load the scenario that we have cd is bisector of acb and + # ce is the ex-bisector of acb. + p = pr.Problem.from_txt( + 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' + ' perp d c c e' + ) + g, _ = gh.Graph.build_problem(p, ARTest.defs) + + # Create an external angle table: + tb = ar.AngleTable('pi') + + # Add bisector & ex-bisector facts into the table: + ca, cd, cb, ce = g.names2nodes(['d(ac)', 'd(cd)', 'd(bc)', 'd(ce)']) + tb.add_eqangle(ca, cd, cd, cb, 'fact1') + tb.add_eqangle(ce, ca, cb, ce, 'fact2') + + # Add a distractor fact to make sure traceback does not include this fact + ab = g.names2nodes(['d(ab)'])[0] + tb.add_eqangle(ab, cb, cb, ca, 'fact3') + + # Check for all new equalities + result = list(tb.get_all_eqs_and_why()) + + # halfpi is represented as a tuple (1, 2) + halfpi = (1, 2) + + # check that cd-ce == halfpi and this is because fact1 & fact2, not fact3 + self.assertCountEqual( + result, + [ + (cd, ce, halfpi, ['fact1', 'fact2']), + (ce, cd, halfpi, ['fact1', 'fact2']), + ], + ) + + def test_angle_table_equilateral_triangle(self): + """Test that AR can figure out triangles with 3 equal angles => each is pi/3.""" + # Load an equaliteral scenario + p = pr.Problem.from_txt('a b c = 
ieq_triangle ? cong a b a c') + g, _ = gh.Graph.build_problem(p, ARTest.defs) + + # Add two eqangles facts because ieq_triangle only add congruent sides + a, b, c = g.names2nodes('abc') + g.add_eqangle([a, b, b, c, b, c, c, a], pr.EmptyDependency(0, None)) + g.add_eqangle([b, c, c, a, c, a, a, b], pr.EmptyDependency(0, None)) + + # Create an external angle table: + tb = ar.AngleTable('pi') + + # Add the fact that there are three equal angles + ab, bc, ca = g.names2nodes(['d(ab)', 'd(bc)', 'd(ac)']) + tb.add_eqangle(ab, bc, bc, ca, 'fact1') + tb.add_eqangle(bc, ca, ca, ab, 'fact2') + + # Now check for all new equalities + result = list(tb.get_all_eqs_and_why()) + result = [(x.name, y.name, z, t) for x, y, z, t in result] + + # 1/3 pi is represented as a tuple angle_60 + angle_60 = (1, 3) + angle_120 = (2, 3) + + # check that angles constants are created and figured out: + self.assertCountEqual( + result, + [ + ('d(bc)', 'd(ac)', angle_120, ['fact1', 'fact2']), + ('d(ab)', 'd(bc)', angle_120, ['fact1', 'fact2']), + ('d(ac)', 'd(ab)', angle_120, ['fact1', 'fact2']), + ('d(ac)', 'd(bc)', angle_60, ['fact1', 'fact2']), + ('d(bc)', 'd(ab)', angle_60, ['fact1', 'fact2']), + ('d(ab)', 'd(ac)', angle_60, ['fact1', 'fact2']), + ], + ) + + def test_incenter_excenter_touchpoints(self): + """Test that AR can figure out incenter/excenter touchpoints are equidistant to midpoint.""" + + p = pr.Problem.from_txt( + 'a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e =' + ' excenter2 a b c ? perp d c c e', + translate=False, + ) + g, _ = gh.Graph.build_problem(p, ARTest.defs) + + a, b, c, ab, bc, ca, d1, d2, d3, e1, e2, e3 = g.names2nodes( + ['a', 'b', 'c', 'ab', 'bc', 'ac', 'd1', 'd2', 'd3', 'e1', 'e2', 'e3'] + ) + + # Create an external distance table: + tb = ar.DistanceTable() + + # DD can figure out the following facts, + # we manually add them to AR. 
+    tb.add_cong(ab, ca, a, d3, a, d2, 'fact1')
+    tb.add_cong(ab, ca, a, e3, a, e2, 'fact2')
+    tb.add_cong(ca, bc, c, d2, c, d1, 'fact5')
+    tb.add_cong(ca, bc, c, e2, c, e1, 'fact6')
+    tb.add_cong(bc, ab, b, d1, b, d3, 'fact3')
+    tb.add_cong(bc, ab, b, e1, b, e3, 'fact4')
+
+    # Now we check whether tb has figured out that
+    # distance(b, d1) == distance(e1, c)
+
+    # linear combination expression of each variable:
+    b = tb.v2e['bc:b']
+    c = tb.v2e['bc:c']
+    d1 = tb.v2e['bc:d1']
+    e1 = tb.v2e['bc:e1']
+
+    self.assertEqual(ar.minus(d1, b), ar.minus(c, e1))
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/beam_search.py b/beam_search.py
new file mode 100644
index 0000000..a11c81c
--- /dev/null
+++ b/beam_search.py
@@ -0,0 +1,463 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Fast decoding routines for inference from a trained model.
+
+Modified https://github.com/google/flax/blob/main/examples/wmt/decode.py
+to accommodate
+
+(a) continued decoding from a previous beam cache.
+(b) init with a single beam and then expand into beam_size beams.
+"""
+
+from typing import Any
+
+import flax
+import jax
+from jax import lax
+import jax.numpy as jnp
+import numpy as np
+
+
+# Constants
+# "Effective negative infinity" constant for masking in beam search.
+NEG_INF = np.array(-1.0e7) + +# Beam search parameters +BEAM_SEARCH_DEFAULT_ALPHA = 0.6 +MAX_DECODE_LEN = 32 + +# Brevity penalty parameters +BREVITY_LEN_BIAS_NUMERATOR = 5.0 +BREVITY_LEN_BIAS_DENOMINATOR = 6.0 + + +def brevity_penalty(alpha: float, length: int): + """Brevity penalty function for beam search penalizing short sequences. + + Args: + alpha: float: brevity-penalty scaling parameter. + length: int: length of considered sequence. + + Returns: + Brevity penalty score as jax scalar. + """ + return jnp.power( + ((BREVITY_LEN_BIAS_NUMERATOR + length) / BREVITY_LEN_BIAS_DENOMINATOR), + alpha, + ) + + +# Beam handling utility functions: + + +def add_beam_dim(x: jnp.ndarray, beam_size: int) -> jnp.ndarray: + """Creates new beam dimension in non-scalar array and tiles into it.""" + if x.ndim == 0: # ignore scalars (e.g. cache index) + return x + x = jnp.expand_dims(x, axis=1) + tile_dims = [1] * x.ndim + tile_dims[1] = beam_size + return jnp.tile(x, tile_dims) + + +def add_beam_dim_cache( + cache: tuple[dict[str, jnp.ndarray], ...], beam_size: int +) -> tuple[dict[str, jnp.ndarray], ...]: + """Creates new beam dimension in non-scalar array and tiles into it.""" + new_cache = [] + + for layer in cache: + new_layer = {} + for key, x in layer.items(): + if key in ['keys', 'vals']: + x = add_beam_dim(x, beam_size) + new_layer[key] = x + new_cache.append(new_layer) + + return tuple(new_cache) + + +def flatten_beam_dim(x): + """Flattens the first two dimensions of a non-scalar array.""" + if x.ndim < 2: # ignore scalars (e.g. cache index) + return x + return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:]) + + +def unflatten_beam_dim(x, batch_size, beam_size): + """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" + if x.ndim == 0: # ignore scalars (e.g. 
cache index)
+    return x
+  assert batch_size * beam_size == x.shape[0]
+  return x.reshape((batch_size, beam_size) + x.shape[1:])
+
+
+def flat_batch_beam_expand(x, beam_size):
+  """Expands each batch item by beam_size in the batch dimension."""
+  return flatten_beam_dim(add_beam_dim(x, beam_size))
+
+
+def gather_beams(nested, beam_indices, batch_size, new_beam_size):
+  """Gathers the beam slices indexed by beam_indices into new beam array.
+
+  Args:
+    nested: pytree of arrays or scalars (the latter ignored).
+    beam_indices: array of beam indices to gather.
+    batch_size: int: size of batch.
+    new_beam_size: int: size of _new_ beam dimension.
+
+  Returns:
+    New pytree with new beam arrays.
+    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
+  """
+  batch_indices = jnp.reshape(
+      jnp.arange(batch_size * new_beam_size) // new_beam_size,
+      (batch_size, new_beam_size),
+  )
+
+  def gather_fn(x):
+    if x.ndim == 0:  # ignore scalars (e.g. cache index)
+      return x
+    else:
+      return x[batch_indices, beam_indices]
+
+  return jax.tree_util.tree_map(gather_fn, nested)
+
+
+def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
+  """Gathers the top-k beam slices given by score_or_log_prob array.
+
+  Args:
+    nested: pytree of arrays or scalars (the latter ignored).
+    score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
+      for top-k selection of beam slices.
+    batch_size: int: size of batch.
+    new_beam_size: int: size of _new_ top-k selected beam dimension.
+
+  Returns:
+    New pytree with new beam arrays containing top k new_beam_size slices.
+    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
+  """
+  _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size)
+  topk_indices = jnp.flip(topk_indices, axis=1)
+  return gather_beams(nested, topk_indices, batch_size, new_beam_size)
+
+
+def apply_on_cache(fn, cache, *args, **kwargs):
+  """Apply fn(val) only on the 'keys', 'values', 'current_index' and 'relative_position_bias' entries."""
+  new_cache = []
+  for layer in cache:
+    new_layer = {}
+    for key, val in layer.items():
+      if key in ['keys', 'values', 'current_index', 'relative_position_bias']:
+        val = fn(val, *args, **kwargs)
+      new_layer[key] = val
+    new_cache.append(new_layer)
+  return tuple(new_cache)
+
+
+# Beam search state:
+
+
+@flax.struct.dataclass
+class BeamState:
+  """Holds beam search state data."""
+
+  # The position of the decoding loop in the length dimension.
+  cur_index: jax.Array  # scalar int32: current decoded length index
+  # The active sequence log probabilities and finished sequence scores.
+  live_logprobs: jax.Array  # float32: [batch_size, beam_size]
+  finished_scores: jax.Array  # float32: [batch_size, beam_size]
+  # The current active-beam-searching and finished sequences.
+  live_seqs: jax.Array  # int32: [batch_size, beam_size, max_decode_len]
+  finished_seqs: jax.Array  # int32: [batch_size, beam_size,
+  #                                   max_decode_len]
+  # Records which of the 'finished_seqs' is occupied and not a filler slot.
+  finished_flags: jax.Array  # bool: [batch_size, beam_size]
+  # The current state of the autoregressive decoding caches.
+  cache: Any  # Any pytree of arrays, e.g.
flax attention Cache object
+
+
+def beam_init(seed_token, batch_size, beam_size, max_decode_len, cache):
+  """Initializes the beam search state data structure."""
+  cur_index0 = jnp.array(0)
+  live_logprobs0 = jnp.tile(
+      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
+  )
+  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
+
+  live_seqs0 = jnp.concatenate(
+      [
+          jnp.reshape(seed_token, (batch_size, beam_size, 1)),
+          jnp.zeros((batch_size, beam_size, max_decode_len - 1), jnp.int32),
+      ],
+      axis=-1,
+  )  # (batch, beam, max_decode_len)
+
+  finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
+  beam_cache0 = apply_on_cache(lambda x: jnp.expand_dims(x, axis=0), cache)
+  return BeamState(
+      cur_index=cur_index0,
+      live_logprobs=live_logprobs0,
+      finished_scores=finished_scores0,
+      live_seqs=live_seqs0,
+      finished_seqs=finished_seqs0,
+      finished_flags=finished_flags0,
+      cache=beam_cache0,
+  )
+
+
+# Beam search routine:
+
+
+def beam_search_flat(
+    seed_token,
+    cache,
+    tokens_to_logits,
+    alpha=BEAM_SEARCH_DEFAULT_ALPHA,
+    eos=None,
+    max_decode_len=MAX_DECODE_LEN,
+    mask=None,
+):
+  """Beam search for LM.
+
+  Inputs and cache are already flat, i.e. first dimension == batch * beam.
+
+  Args:
+    seed_token: array: [beam_size, 1] int32 sequence of tokens.
+    cache: flax attention cache.
+    tokens_to_logits: fast autoregressive decoder function taking single token
+      slices and cache and returning next-token logits and updated cache.
+    alpha: float: scaling factor for brevity penalty.
+    eos: array: [vocab] 1 for end-of-sentence tokens, 0 otherwise.
+    max_decode_len: int: maximum length of decoded sequences.
+    mask: array: [vocab] binary mask for vocab; 1 to keep a token's
+      probability, 0 to zero it out.
+
+  Returns:
+    Tuple of:
+      [beam_size, max_decode_len] top-scoring sequences
+      [beam_size] beam-search scores.
+ """ + # We liberally annotate shape information for clarity below. + batch_size, beam_size = 1, seed_token.shape[0] + mask = mask.reshape((1, 1, -1)) + eos = eos.reshape((1, 1, -1)) + mask_bias = (1 - mask) * NEG_INF + + # initialize beam search state + beam_search_init_state = beam_init( + seed_token, batch_size, beam_size, max_decode_len, cache + ) + + def beam_search_loop_cond_fn(state): + """Beam search loop termination condition.""" + # Have we reached max decoding length? + not_at_end = state.cur_index < max_decode_len - 1 + + # Is no further progress in the beam search possible? + # Get the best possible scores from alive sequences. + min_brevity_penalty = brevity_penalty(alpha, max_decode_len) + best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty + # Get the worst scores from finished sequences. + worst_finished_scores = jnp.min( + state.finished_scores, axis=1, keepdims=True + ) + # Mask out scores from slots without any actual finished sequences. + worst_finished_scores = jnp.where( + state.finished_flags, worst_finished_scores, NEG_INF + ) + # If no best possible live score is better than current worst finished + # scores, the search cannot improve the finished set further. + search_terminated = jnp.all(worst_finished_scores > best_live_scores) + + # If we're not at the max decode length, and the search hasn't terminated, + # continue looping. + return not_at_end & (~search_terminated) + + def beam_search_loop_body_fn(state): + """Beam search loop state update function.""" + # Collect the current position slice along length to feed the fast + # autoregressive decoder model. Flatten the beam dimension into batch + # dimension for feeding into the model. + # --> [batch * beam, 1] + flat_ids = flatten_beam_dim( + lax.dynamic_slice( + state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) + ) + ) + # Flatten beam dimension into batch to be compatible with model. 
+ # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} + flat_cache = apply_on_cache(flatten_beam_dim, state.cache) + + # Call fast-decoder model on current tokens to get next-position logits. + # --> [batch * beam, vocab] + flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache) + + # unflatten beam dimension + # [batch * beam, vocab] --> [batch, beam, vocab] + logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) + + # Unflatten beam dimension in attention cache arrays + # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} + new_cache = apply_on_cache( + unflatten_beam_dim, new_flat_cache, batch_size, beam_size + ) + + # Gather log probabilities from logits + candidate_log_probs = jax.nn.log_softmax(logits) + # Add new logprobs to existing prefix logprobs. + # --> [batch, beam, vocab] + log_probs = candidate_log_probs + jnp.expand_dims( + state.live_logprobs, axis=2 + ) + + # We'll need the vocab size, gather it from the log probability dimension. + vocab_size = log_probs.shape[2] + + # mask away some tokens. + log_probs += mask_bias # [batch,beam,vocab]+[1,1,vocab] + + # Each item in batch has beam_size * vocab_size candidate sequences. + # For each item, get the top 2*k candidates with the highest log- + # probabilities. We gather the top 2*K beams here so that even if the best + # K sequences reach EOS simultaneously, we have another K sequences + # remaining to continue the live beam search. + beams_to_keep = 2 * beam_size + # Flatten beam and vocab dimensions. + flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) + # Gather the top 2*K scores from _all_ beams. + # --> [batch, 2*beams], [batch, 2*beams] + topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep) + # Recover the beam index by floor division. + topk_beam_indices = topk_indices // vocab_size + # Gather 2*k top beams. 
+ # --> [batch, 2*beams, length] + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) + + # Append the most probable 2*K token IDs to the top 2*K sequences + # Recover token id by modulo division and expand Id array for broadcasting. + # --> [batch, 2*beams, 1] + topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) + # Update sequences for the 2*K top-k new sequences. + # --> [batch, 2*beams, length] + topk_seq = lax.dynamic_update_slice( + topk_seq, topk_ids, (0, 0, state.cur_index + 1) + ) + + # Update LIVE (in-progress) sequences: + # Did any of these sequences reach an end marker? + # --> [batch, 2*beams] + last_token = topk_seq[:, :, state.cur_index + 1] + last_token = jax.nn.one_hot(last_token, vocab_size, dtype=jnp.bfloat16) + + # any([batch, 2b, vocab] * [1, 1, vocab], axis=-1) == [batch, 2b] + newly_finished = jnp.any(last_token * eos, axis=-1) + + # To prevent these newly finished sequences from being added to the LIVE + # set of active beam search sequences, set their log probs to a very large + # negative value. + new_log_probs = topk_log_probs + newly_finished * NEG_INF + # Determine the top k beam indices (from top 2*k beams) from log probs. + # --> [batch, beams] + _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) + new_topk_indices = jnp.flip(new_topk_indices, axis=1) + # Gather the top k beams (from top 2*k beams). + # --> [batch, beams, length], [batch, beams] + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) + + # Determine the top k beam indices from the original set of all beams. + # --> [batch, beams] + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) + # With these, gather the top k beam-associated caches. 
+ # --> {[batch, beams, ...], ...} + top_alive_cache = apply_on_cache( + gather_beams, new_cache, top_alive_indices, batch_size, beam_size + ) + + # Update FINISHED (reached end of sentence) sequences: + # Calculate new seq scores from log probabilities. + new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) + # Mask out the still unfinished sequences by adding large negative value. + # --> [batch, 2*beams] + new_scores += (~newly_finished) * NEG_INF + + # Combine sequences, scores, and flags along the beam dimension and compare + # new finished sequence scores to existing finished scores and select the + # best from the new set of beams. + finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] + [state.finished_seqs, topk_seq], axis=1 + ) + finished_scores = jnp.concatenate( # --> [batch, 3*beams] + [state.finished_scores, new_scores], axis=1 + ) + finished_flags = jnp.concatenate( # --> [batch, 3*beams] + [state.finished_flags, newly_finished], axis=1 + ) + # --> [batch, beams, length], [batch, beams], [batch, beams] + top_finished_seq, top_finished_scores, top_finished_flags = ( + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) + + return BeamState( + cur_index=state.cur_index + 1, + live_logprobs=top_alive_log_probs, + finished_scores=top_finished_scores, + live_seqs=top_alive_seq, + finished_seqs=top_finished_seq, + finished_flags=top_finished_flags, + cache=top_alive_cache, + ) + + # Run while loop and get final beam search state. + final_state = lax.while_loop( + beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state + ) + + # Account for the edge-case where there are no finished sequences for a + # particular batch item. If so, return live sequences for that batch item. 
+ # --> [batch] + none_finished = jnp.any(final_state.finished_flags, axis=1) + # --> [batch, beams, length] + finished_seqs = jnp.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) + # --> [batch, beams] + finished_scores = jnp.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) + + finished_seqs = jnp.reshape(finished_seqs, (beam_size, max_decode_len)) + finished_scores = jnp.reshape(finished_scores, (beam_size,)) + + final_cache = apply_on_cache(flatten_beam_dim, final_state.cache) + return finished_seqs, finished_scores, final_cache diff --git a/dd.py b/dd.py new file mode 100644 index 0000000..19325ff --- /dev/null +++ b/dd.py @@ -0,0 +1,1156 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Implements Deductive Database (DD)."""
+
+# pylint: disable=g-multiple-import,g-importing-member
+from collections import defaultdict
+import time
+from typing import Any, Callable, Generator
+
+import geometry as gm
+import graph as gh
+import graph_utils as utils
+import numericals as nm
+import problem as pr
+from problem import Dependency, EmptyDependency
+
+
+def intersect1(set1: set[Any], set2: set[Any]) -> Any:
+  for x in set1:
+    if x in set2:
+      return x
+  return None
+
+
+def diff_point(l: gm.Line, a: gm.Point) -> gm.Point:
+  for x in l.neighbors(gm.Point):
+    if x != a:
+      return x
+  return None
+
+
+# pylint: disable=protected-access
+# pylint: disable=unused-argument
+
+
+def match_eqratio_eqratio_eqratio(
+    g: gh.Graph,
+    g_matcher: Callable[str, list[tuple[gm.Point, ...]]],
+    theorem: pr.Theorem,
+) -> Generator[dict[str, gm.Point], None, None]:
+  """Match eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u."""
+  for m1 in g.type2nodes[gm.Value]:
+    for m2 in g.type2nodes[gm.Value]:
+      rats1 = []
+      for rat in m1.neighbors(gm.Ratio):
+        l1, l2 = rat.lengths
+        if l1 is None or l2 is None:
+          continue
+        rats1.append((l1, l2))
+
+      rats2 = []
+      for rat in m2.neighbors(gm.Ratio):
+        l1, l2 = rat.lengths
+        if l1 is None or l2 is None:
+          continue
+        rats2.append((l1, l2))
+
+      pairs = []
+      for (l1, l2), (l3, l4) in utils.cross(rats1, rats2):
+        if l2 == l3:
+          pairs.append((l1, l2, l4))
+
+      for (l1, l12, l2), (l3, l34, l4) in utils.comb2(pairs):
+        if (l1, l12, l2) == (l3, l34, l4):
+          continue
+        if l1 == l2 or l3 == l4:
+          continue
+        if l1 == l12 or l12 == l2 or l3 == l34 or l4 == l34:
+          continue
+        # l1 / l12 = l3 / l34 = m1
+        # l12 / l2 = l34 / l4 = m2
+        # => l1 / l2 = l3 / l4 (= m1 * m2)
+        a, b = g.two_points_of_length(l1)
+        c, d = g.two_points_of_length(l12)
+        m, n = g.two_points_of_length(l3)
+        p, q = g.two_points_of_length(l34)
+        # eqratio a b c d m n p q
+        e,
f = g.two_points_of_length(l2) + r, u = g.two_points_of_length(l4) + yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u])) + + +def match_eqangle_eqangle_eqangle( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u.""" + for m1 in g.type2nodes[gm.Measure]: + for m2 in g.type2nodes[gm.Measure]: + angs1 = [] + for ang in m1.neighbors(gm.Angle): + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + angs1.append((d1, d2)) + + angs2 = [] + for ang in m2.neighbors(gm.Angle): + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + angs2.append((d1, d2)) + + pairs = [] + for (d1, d2), (d3, d4) in utils.cross(angs1, angs2): + if d2 == d3: + pairs.append((d1, d2, d4)) + + for (d1, d12, d2), (d3, d34, d4) in utils.comb2(pairs): + if (d1, d12, d2) == (d3, d34, d4): + continue + if d1 == d2 or d3 == d4: + continue + if d1 == d12 or d12 == d2 or d3 == d34 or d4 == d34: + continue + # d12 - d1 = d34 - d3 = m1 + # d2 - d12 = d4 - d34 = m2 + # => d2 - d1 = d4 - d3 + a, b = g.two_points_on_direction(d1) + c, d = g.two_points_on_direction(d12) + m, n = g.two_points_on_direction(d3) + p, q = g.two_points_on_direction(d34) + # eqangle a b c d m n p q + e, f = g.two_points_on_direction(d2) + r, u = g.two_points_on_direction(d4) + yield dict(zip('abcdefmnpqru', [a, b, c, d, e, f, m, n, p, q, r, u])) + + +def match_perp_perp_npara_eqangle( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H.""" + dpairs = [] + for ang in g.vhalfpi.neighbors(gm.Angle): + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + dpairs.append((d1, d2)) + + for (d1, d2), (d3, d4) in utils.comb2(dpairs): + a, 
b = g.two_points_on_direction(d1) + c, d = g.two_points_on_direction(d2) + m, n = g.two_points_on_direction(d3) + p, q = g.two_points_on_direction(d4) + if g.check_npara([a, b, m, n]): + if ({a, b}, {c, d}) == ({m, n}, {p, q}): + continue + if ({a, b}, {c, d}) == ({p, q}, {m, n}): + continue + + yield dict(zip('ABCDEFGH', [a, b, c, d, m, n, p, q])) + + +def match_circle_coll_eqangle_midp( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C.""" + for p, a, b, c in g.all_circles(): + ab = g._get_line(a, b) + if ab is None: + continue + if ab.val is None: + continue + ac = g._get_line(a, c) + if ac is None: + continue + if ac.val is None: + continue + pb = g._get_line(p, b) + if pb is None: + continue + if pb.val is None: + continue + + bc = g._get_line(b, c) + if bc is None: + continue + bc_points = bc.neighbors(gm.Point, return_set=True) + + anga, _ = g._get_angle(ab.val, ac.val) + + for angp in pb.val.neighbors(gm.Angle): + if not g.is_equal(anga, angp): + continue + + _, d = angp.directions + for l in d.neighbors(gm.Line): + l_points = l.neighbors(gm.Point, return_set=True) + m = intersect1(bc_points, l_points) + if m is not None: + yield dict(zip('ABCMO', [a, b, c, m, p])) + + +def match_midp_perp_cong( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match midp M A B, perp O M A B => cong O A O B.""" + for m, a, b in g.all_midps(): + ab = g._get_line(a, b) + for l in m.neighbors(gm.Line): + if g.check_perpl(l, ab): + for o in l.neighbors(gm.Point): + if o != m: + yield dict(zip('ABMO', [a, b, m, o])) + + +def match_cyclic_eqangle_cong( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match cyclic A B 
C P Q R, eqangle C A C B R P R Q => cong A B P Q.""" + for c in g.type2nodes[gm.Circle]: + ps = c.neighbors(gm.Point) + for (a, b, c), (x, y, z) in utils.comb2(list(utils.perm3(ps))): + if {a, b, c} == {x, y, z}: + continue + if g.check_eqangle([c, a, c, b, z, x, z, y]): + yield dict(zip('ABCPQR', [a, b, c, x, y, z])) + + +def match_circle_eqangle_perp( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match circle O A B C, eqangle A X A B C A C B => perp O A A X.""" + for p, a, b, c in g.all_circles(): + ca = g._get_line(c, a) + if ca is None: + continue + cb = g._get_line(c, b) + if cb is None: + continue + ab = g._get_line(a, b) + if ab is None: + continue + + if ca.val is None: + continue + if cb.val is None: + continue + if ab.val is None: + continue + + c_ang, _ = g._get_angle(cb.val, ca.val) + if c_ang is None: + continue + + for ang in ab.val.neighbors(gm.Angle): + if g.is_equal(ang, c_ang): + _, d = ang.directions + for l in d.neighbors(gm.Line): + if a not in l.neighbors(gm.Point): + continue + x = diff_point(l, a) + if x is None: + continue + yield dict(zip('OABCX', [p, a, b, c, x])) + break + + +def match_circle_perp_eqangle( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match circle O A B C, perp O A A X => eqangle A X A B C A C B.""" + for p, a, b, c in g.all_circles(): + pa = g._get_line(p, a) + if pa is None: + continue + if pa.val is None: + continue + for l in a.neighbors(gm.Line): + if g.check_perpl(pa, l): + x = diff_point(l, a) + if x is not None: + yield dict(zip('OABCX', [p, a, b, c, x])) + + +def match_perp_perp_ncoll_para( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match perp A B C D, perp C D E F, ncoll A B E => para A B E F.""" + d2d = 
defaultdict(list) + for ang in g.vhalfpi.neighbors(gm.Angle): + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + d2d[d1] += [d2] + d2d[d2] += [d1] + + for x, ys in d2d.items(): + if len(ys) < 2: + continue + c, d = g.two_points_on_direction(x) + for y1, y2 in utils.comb2(ys): + a, b = g.two_points_on_direction(y1) + e, f = g.two_points_on_direction(y2) + if nm.check_ncoll([a.num, b.num, e.num]): + yield dict(zip('ABCDEF', [a, b, c, d, e, f])) + + +def match_eqangle6_ncoll_cong( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 A O A B B A B O, ncoll O A B => cong O A O B.""" + for a in g.type2nodes[gm.Point]: + for b, c in utils.comb2(g.type2nodes[gm.Point]): + if a == b or a == c: + continue + if g.check_eqangle([b, a, b, c, c, b, c, a]): + if g.check_ncoll([a, b, c]): + yield dict(zip('OAB', [a, b, c])) + + +def match_eqangle_perp_perp( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle A B P Q C D U V, perp P Q U V => perp A B C D.""" + for ang in g.vhalfpi.neighbors(gm.Angle): + # d1 perp d2 + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + for d3, d4 in utils.comb2(g.type2nodes[gm.Direction]): + if d1 == d3 or d2 == d4: + continue + # if d1 - d3 = d2 - d4 => d3 perp d4 + a13, a31 = g._get_angle(d1, d3) + a24, a42 = g._get_angle(d2, d4) + if a13 is None or a31 is None or a24 is None or a42 is None: + continue + if g.is_equal(a13, a24) and g.is_equal(a31, a42): + a, b = g.two_points_on_direction(d1) + c, d = g.two_points_on_direction(d2) + m, n = g.two_points_on_direction(d3) + p, q = g.two_points_on_direction(d4) + yield dict(zip('ABCDPQUV', [m, n, p, q, a, b, c, d])) + + +def match_eqangle_ncoll_cyclic( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) 
-> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q.""" + for l1, l2, l3, l4 in g.all_eqangles_distinct_linepairss(): + if len(set([l1, l2, l3, l4])) < 4: + continue # they all must be distinct. + + p1s = l1.neighbors(gm.Point, return_set=True) + p2s = l2.neighbors(gm.Point, return_set=True) + p3s = l3.neighbors(gm.Point, return_set=True) + p4s = l4.neighbors(gm.Point, return_set=True) + + p = intersect1(p1s, p2s) + if not p: + continue + q = intersect1(p3s, p4s) + if not q: + continue + a = intersect1(p1s, p3s) + if not a: + continue + b = intersect1(p2s, p4s) + if not b: + continue + if len(set([a, b, p, q])) < 4: + continue + + if not g.check_ncoll([a, b, p, q]): + continue + + yield dict(zip('ABPQ', [a, b, p, q])) + + +def match_eqangle_para( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle A B P Q C D P Q => para A B C D.""" + for measure in g.type2nodes[gm.Measure]: + angs = measure.neighbors(gm.Angle) + d12, d21 = defaultdict(list), defaultdict(list) + for ang in angs: + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + d12[d1].append(d2) + d21[d2].append(d1) + + for d1, d2s in d12.items(): + a, b = g.two_points_on_direction(d1) + for d2, d3 in utils.comb2(d2s): + c, d = g.two_points_on_direction(d2) + e, f = g.two_points_on_direction(d3) + yield dict(zip('ABCDPQ', [c, d, e, f, a, b])) + + +def match_cyclic_eqangle( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match cyclic A B P Q => eqangle P A P B Q A Q B.""" + record = set() + for a, b, c, d in g_matcher('cyclic'): + if (a, b, c, d) in record: + continue + record.add((a, b, c, d)) + record.add((a, b, d, c)) + record.add((b, a, c, d)) + record.add((b, a, d, c)) + yield dict(zip('ABPQ', [a, b, c, d])) + + +def 
rotate_simtri( + a: gm.Point, b: gm.Point, c: gm.Point, x: gm.Point, y: gm.Point, z: gm.Point +) -> Generator[tuple[gm.Point, ...], None, None]: + """Rotate points around for similar triangle predicates.""" + yield (z, y, x, c, b, a) + for p in [ + (b, c, a, y, z, x), + (c, a, b, z, x, y), + (x, y, z, a, b, c), + (y, z, x, b, c, a), + (z, x, y, c, a, b), + ]: + yield p + yield p[::-1] + + +def match_cong_cong_cong_cyclic( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match cong O A O B, cong O B O C, cong O C O D => cyclic A B C D.""" + for l in g.type2nodes[gm.Length]: + p2p = defaultdict(list) + for s in l.neighbors(gm.Segment): + a, b = s.points + p2p[a].append(b) + p2p[b].append(a) + + for p, ps in p2p.items(): + if len(ps) >= 4: + for a, b, c, d in utils.comb4(ps): + yield dict(zip('OABCD', [p, a, b, c, d])) + + +def match_cong_cong_cong_ncoll_contri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R.""" + record = set() + for a, b, p, q in g_matcher('cong'): + for c in g.type2nodes[gm.Point]: + for r in g.type2nodes[gm.Point]: + if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]): + continue + if not g.check_ncoll([a, b, c]): + continue + if g.check_cong([b, c, q, r]) and g.check_cong([c, a, r, p]): + record.add((a, b, c, p, q, r)) + yield dict(zip('ABCPQR', [a, b, c, p, q, r])) + + +def match_cong_cong_eqangle6_ncoll_contri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R.""" + record = set() + for a, b, p, q in g_matcher('cong'): + for c in g.type2nodes[gm.Point]: + if c in 
(a, b): + continue + for r in g.type2nodes[gm.Point]: + if r in (p, q): + continue + + in_record = False + for x in [ + (c, b, a, r, q, p), + (p, q, r, a, b, c), + (r, q, p, c, b, a), + ]: + if x in record: + in_record = True + break + + if in_record: + continue + + if not g.check_cong([b, c, q, r]): + continue + if not g.check_ncoll([a, b, c]): + continue + + if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num): + if g.check_eqangle([b, a, b, c, q, p, q, r]): + record.add((a, b, c, p, q, r)) + yield dict(zip('ABCPQR', [a, b, c, p, q, r])) + else: + if g.check_eqangle([b, a, b, c, q, r, q, p]): + record.add((a, b, c, p, q, r)) + yield dict(zip('ABCPQR', [a, b, c, p, q, r])) + + +def match_eqratio6_eqangle6_ncoll_simtri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R.""" + enums = g_matcher('eqratio6') + + record = set() + for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]): + continue + if not g.check_ncoll([a, b, c]): + continue + + if nm.same_clock(a.num, b.num, c.num, p.num, q.num, r.num): + if g.check_eqangle([b, a, b, c, q, p, q, r]): + record.add((a, b, c, p, q, r)) + yield dict(zip('ABCPQR', [a, b, c, p, q, r])) + elif g.check_eqangle([b, a, b, c, q, r, q, p]): + record.add((a, b, c, p, q, r)) + yield dict(zip('ABCPQR', [a, b, c, p, q, r])) + + +def match_eqangle6_eqangle6_ncoll_simtri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R.""" + enums = g_matcher('eqangle6') + + record = set() + for b, a, b, c, q, p, q, r in enums: 
# pylint: disable=redeclared-assigned-name,unused-variable + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]): + continue + if not g.check_eqangle([c, a, c, b, r, p, r, q]): + continue + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def match_eqratio6_eqratio6_ncoll_simtri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R.""" + enums = g_matcher('eqratio6') + + record = set() + for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]): + continue + if not g.check_eqratio([c, a, c, b, r, p, r, q]): + continue + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def match_eqangle6_eqangle6_ncoll_simtri2( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R.""" + enums = g_matcher('eqangle6') + + record = set() + for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_simtri(a, b, c, p, q, r)]): + continue + if not g.check_eqangle([c, a, c, b, r, q, r, p]): + continue + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def rotate_contri( + a: gm.Point, b: gm.Point, c: gm.Point, x: 
gm.Point, y: gm.Point, z: gm.Point +) -> Generator[tuple[gm.Point, ...], None, None]: + for p in [(b, a, c, y, x, z), (x, y, z, a, b, c), (y, x, z, b, a, c)]: + yield p + + +def match_eqangle6_eqangle6_ncoll_cong_contri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R.""" + enums = g_matcher('eqangle6') + + record = set() + for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if not g.check_cong([a, b, p, q]): + continue + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_contri(a, b, c, p, q, r)]): + continue + if not g.check_eqangle([c, a, c, b, r, p, r, q]): + continue + + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def match_eqratio6_eqratio6_ncoll_cong_contri( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R.""" + enums = g_matcher('eqratio6') + + record = set() + for b, a, b, c, q, p, q, r in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if not g.check_cong([a, b, p, q]): + continue + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_contri(a, b, c, p, q, r)]): + continue + if not g.check_eqratio([c, a, c, b, r, p, r, q]): + continue + + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def match_eqangle6_eqangle6_ncoll_cong_contri2( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) 
-> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R.""" + enums = g_matcher('eqangle6') + + record = set() + for b, a, b, c, q, r, q, p in enums: # pylint: disable=redeclared-assigned-name,unused-variable + if not g.check_cong([a, b, p, q]): + continue + if (a, b, c) == (p, q, r): + continue + if any([x in record for x in rotate_contri(a, b, c, p, q, r)]): + continue + if not g.check_eqangle([c, a, c, b, r, q, r, p]): + continue + if not g.check_ncoll([a, b, c]): + continue + + mapping = dict(zip('ABCPQR', [a, b, c, p, q, r])) + record.add((a, b, c, p, q, r)) + yield mapping + + +def match_eqratio6_coll_ncoll_eqangle6( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c.""" + records = set() + for b, d, c in g_matcher('coll'): + for a in g.all_points(): + if g.check_coll([a, b, c]): + continue + if (a, b, d, c) in records or (a, c, d, b) in records: + continue + records.add((a, b, d, c)) + + if g.check_eqratio([d, b, d, c, a, b, a, c]): + yield dict(zip('abcd', [a, b, c, d])) + + +def match_eqangle6_coll_ncoll_eqratio6( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem, +) -> Generator[dict[str, gm.Point], None, None]: + """Match eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c.""" + records = set() + for b, d, c in g_matcher('coll'): + for a in g.all_points(): + if g.check_coll([a, b, c]): + continue + if (a, b, d, c) in records or (a, c, d, b) in records: + continue + records.add((a, b, d, c)) + + if g.check_eqangle([a, b, a, d, a, d, a, c]): + yield dict(zip('abcd', [a, b, c, d])) + + +def match_eqangle6_ncoll_cyclic( + g: gh.Graph, + g_matcher: Callable[str, list[tuple[gm.Point, ...]]], + theorem: 
pr.Theorem,
+) -> Generator[dict[str, gm.Point], None, None]:
+  """Match eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q."""
+  for a, b, a, c, x, y, x, z in g_matcher('eqangle6'):  # pylint: disable=redeclared-assigned-name,unused-variable
+    if (b, c) != (y, z) or a == x:
+      continue
+    if nm.check_ncoll([x.num for x in [a, b, c, x]]):
+      yield dict(zip('ABPQ', [b, c, a, x]))
+
+
+def match_all(
+    name: str, g: gh.Graph
+) -> Generator[tuple[gm.Point, ...], None, None]:
+  """Match all instances of a certain relation."""
+  if name in ['ncoll', 'npara', 'nperp']:
+    return []
+  if name == 'coll':
+    return g.all_colls()
+  if name == 'para':
+    return g.all_paras()
+  if name == 'perp':
+    return g.all_perps()
+  if name == 'cong':
+    return g.all_congs()
+  if name == 'eqangle':
+    return g.all_eqangles_8points()
+  if name == 'eqangle6':
+    return g.all_eqangles_6points()
+  if name == 'eqratio':
+    return g.all_eqratios_8points()
+  if name == 'eqratio6':
+    return g.all_eqratios_6points()
+  if name == 'cyclic':
+    return g.all_cyclics()
+  if name == 'midp':
+    return g.all_midps()
+  if name == 'circle':
+    return g.all_circles()
+  raise ValueError(f'Unrecognized {name}')
+
+
+def cache_match(
+    graph: gh.Graph,
+) -> Callable[str, list[tuple[gm.Point, ...]]]:
+  """Cache throughout one single BFS level."""
+  cache = {}
+
+  def match_fn(name: str) -> list[tuple[gm.Point, ...]]:
+    if name in cache:
+      return cache[name]
+
+    result = list(match_all(name, graph))
+    cache[name] = result
+    return result
+
+  return match_fn
+
+
+def try_to_map(
+    clause_enum: list[tuple[pr.Clause, list[tuple[gm.Point, ...]]]],
+    mapping: dict[str, gm.Point],
+) -> Generator[dict[str, gm.Point], None, None]:
+  """Recursively try to match the remaining points given current mapping."""
+  if not clause_enum:
+    yield mapping
+    return
+
+  clause, enum = clause_enum[0]
+  for points in enum:
+    mpcpy = dict(mapping)
+
+    fail = False
+    for p, a in zip(points, clause.args):
+      if a in mpcpy and mpcpy[a] != p or p 
in mpcpy and mpcpy[p] != a: + fail = True + break + mpcpy[a] = p + mpcpy[p] = a + + if fail: + continue + + for m in try_to_map(clause_enum[1:], mpcpy): + yield m + + +def match_generic( + g: gh.Graph, + cache: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem +) -> Generator[dict[str, gm.Point], None, None]: + """Match any generic rule that is not one of the above match_*() rules.""" + clause2enum = {} + + clauses = [] + numerical_checks = [] + for clause in theorem.premise: + if clause.name in ['ncoll', 'npara', 'nperp', 'sameside']: + numerical_checks.append(clause) + continue + + enum = cache(clause.name) + if len(enum) == 0: # pylint: disable=g-explicit-length-test + return 0 + + clause2enum[clause] = enum + clauses.append((len(set(clause.args)), clause)) + + clauses = sorted(clauses, key=lambda x: x[0], reverse=True) + _, clauses = zip(*clauses) + + for mapping in try_to_map([(c, clause2enum[c]) for c in clauses], {}): + if not mapping: + continue + + checks_ok = True + for check in numerical_checks: + args = [mapping[a] for a in check.args] + if check.name == 'ncoll': + checks_ok = g.check_ncoll(args) + elif check.name == 'npara': + checks_ok = g.check_npara(args) + elif check.name == 'nperp': + checks_ok = g.check_nperp(args) + elif check.name == 'sameside': + checks_ok = g.check_sameside(args) + if not checks_ok: + break + if not checks_ok: + continue + + yield mapping + + +BUILT_IN_FNS = { + 'cong_cong_cong_cyclic': match_cong_cong_cong_cyclic, + 'cong_cong_cong_ncoll_contri*': match_cong_cong_cong_ncoll_contri, + 'cong_cong_eqangle6_ncoll_contri*': match_cong_cong_eqangle6_ncoll_contri, + 'eqangle6_eqangle6_ncoll_simtri': match_eqangle6_eqangle6_ncoll_simtri, + 'eqangle6_eqangle6_ncoll_cong_contri': ( + match_eqangle6_eqangle6_ncoll_cong_contri + ), # pylint: disable=line-too-long + 'eqangle6_eqangle6_ncoll_simtri2': match_eqangle6_eqangle6_ncoll_simtri2, + 'eqangle6_eqangle6_ncoll_cong_contri2': ( + 
match_eqangle6_eqangle6_ncoll_cong_contri2 + ), # pylint: disable=line-too-long + 'eqratio6_eqratio6_ncoll_simtri*': match_eqratio6_eqratio6_ncoll_simtri, + 'eqratio6_eqratio6_ncoll_cong_contri*': ( + match_eqratio6_eqratio6_ncoll_cong_contri + ), # pylint: disable=line-too-long + 'eqangle_para': match_eqangle_para, + 'eqangle_ncoll_cyclic': match_eqangle_ncoll_cyclic, + 'eqratio6_eqangle6_ncoll_simtri*': match_eqratio6_eqangle6_ncoll_simtri, + 'eqangle_perp_perp': match_eqangle_perp_perp, + 'eqangle6_ncoll_cong': match_eqangle6_ncoll_cong, + 'perp_perp_ncoll_para': match_perp_perp_ncoll_para, + 'circle_perp_eqangle': match_circle_perp_eqangle, + 'circle_eqangle_perp': match_circle_eqangle_perp, + 'cyclic_eqangle_cong': match_cyclic_eqangle_cong, + 'midp_perp_cong': match_midp_perp_cong, + 'perp_perp_npara_eqangle': match_perp_perp_npara_eqangle, + 'cyclic_eqangle': match_cyclic_eqangle, + 'eqangle_eqangle_eqangle': match_eqangle_eqangle_eqangle, + 'eqratio_eqratio_eqratio': match_eqratio_eqratio_eqratio, + 'eqratio6_coll_ncoll_eqangle6': match_eqratio6_coll_ncoll_eqangle6, + 'eqangle6_coll_ncoll_eqratio6': match_eqangle6_coll_ncoll_eqratio6, + 'eqangle6_ncoll_cyclic': match_eqangle6_ncoll_cyclic, +} + + +SKIP_THEOREMS = set() + + +def set_skip_theorems(theorems: set[str]) -> None: + SKIP_THEOREMS.update(theorems) + + +MAX_BRANCH = 50_000 + + +def match_one_theorem( + g: gh.Graph, + cache: Callable[str, list[tuple[gm.Point, ...]]], + theorem: pr.Theorem +) -> Generator[dict[str, gm.Point], None, None]: + """Match all instances of a single theorem (rule).""" + if cache is None: + cache = cache_match(g) + + if theorem.name in SKIP_THEOREMS: + return [] + + if theorem.name.split('_')[-1] in SKIP_THEOREMS: + return [] + + if theorem.name in BUILT_IN_FNS: + mps = BUILT_IN_FNS[theorem.name](g, cache, theorem) + else: + mps = match_generic(g, cache, theorem) + + mappings = [] + for mp in mps: + mappings.append(mp) + if len(mappings) > MAX_BRANCH: # cap branching at this 
number. + break + + return mappings + + +def match_all_theorems( + g: gh.Graph, theorems: list[pr.Theorem], goal: pr.Clause +) -> dict[pr.Theorem, dict[pr.Theorem, dict[str, gm.Point]]]: + """Match all instances of all theorems (rules).""" + cache = cache_match(g) + # for BFS, collect all potential matches + # and then do it at the same time + theorem2mappings = {} + + # Step 1: list all matches + for _, theorem in theorems.items(): + name = theorem.name + if name.split('_')[-1] in [ + 'acompute', + 'rcompute', + 'fixl', + 'fixc', + 'fixb', + 'fixt', + 'fixp', + ]: + if goal and goal.name != name: + continue + + mappings = match_one_theorem(g, cache, theorem) + if len(mappings): # pylint: disable=g-explicit-length-test + theorem2mappings[theorem] = list(mappings) + return theorem2mappings + + +def bfs_one_level( + g: gh.Graph, + theorems: list[pr.Theorem], + level: int, + controller: pr.Problem, + verbose: bool = False, + nm_check: bool = False, + timeout: int = 600, +) -> tuple[ + list[pr.Dependency], + dict[str, list[tuple[gm.Point, ...]]], + dict[str, list[tuple[gm.Point, ...]]], + int, +]: + """Forward deduce one breadth-first level.""" + + # Step 1: match all theorems: + theorem2mappings = match_all_theorems(g, theorems, controller.goal) + + # Step 2: traceback for each deduce: + theorem2deps = {} + t0 = time.time() + for theorem, mappings in theorem2mappings.items(): + if time.time() - t0 > timeout: + break + mp_deps = [] + for mp in mappings: + deps = EmptyDependency(level=level, rule_name=theorem.rule_name) + fail = False # finding why deps might fail. + + for p in theorem.premise: + p_args = [mp[a] for a in p.args] + # Trivial deps. 
+ if p.name == 'cong': + a, b, c, d = p_args + if {a, b} == {c, d}: + continue + if p.name == 'para': + a, b, c, d = p_args + if {a, b} == {c, d}: + continue + + if theorem.name in [ + 'cong_cong_eqangle6_ncoll_contri*', + 'eqratio6_eqangle6_ncoll_simtri*', + ]: + if p.name in ['eqangle', 'eqangle6']: # SAS or RAR + b, a, b, c, y, x, y, z = ( # pylint: disable=redeclared-assigned-name,unused-variable + p_args + ) + if not nm.same_clock(a.num, b.num, c.num, x.num, y.num, z.num): + p_args = b, a, b, c, y, z, y, x + + dep = Dependency(p.name, p_args, rule_name='', level=level) + try: + dep = dep.why_me_or_cache(g, level) + except: # pylint: disable=bare-except + fail = True + break + + if dep.why is None: + fail = True + break + g.cache_dep(p.name, p_args, dep) + deps.why.append(dep) + + if fail: + continue + + mp_deps.append((mp, deps)) + theorem2deps[theorem] = mp_deps + + theorem2deps = list(theorem2deps.items()) + + # Step 3: add conclusions to graph. + # Note that we do NOT mix step 2 and 3, strictly going for BFS. + added = [] + for theorem, mp_deps in theorem2deps: + for mp, deps in mp_deps: + if time.time() - t0 > timeout: + break + name, args = theorem.conclusion_name_args(mp) + hash_conclusion = pr.hashed(name, args) + if hash_conclusion in g.cache: + continue + + add = g.add_piece(name, args, deps=deps) + added += add + + branching = len(added) + + # Check if goal is found + if controller.goal: + args = [] + + for a in controller.goal.args: + if a in g._name2node: + a = g._name2node[a] + elif '/' in a: + a = create_consts_str(g, a) + elif a.isdigit(): + a = int(a) + args.append(a) + + if g.check(controller.goal.name, args): + return added, {}, {}, branching + + # Run AR, but do NOT apply to the proof state (yet). 
+ for dep in added: + g.add_algebra(dep, level) + derives, eq4s = g.derive_algebra(level, verbose=False) + + branching += sum([len(x) for x in derives.values()]) + branching += sum([len(x) for x in eq4s.values()]) + + return added, derives, eq4s, branching + + +def create_consts_str(g: gh.Graph, s: str) -> gm.Angle | gm.Ratio: + if 'pi/' in s: + n, d = s.split('pi/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_ang(n, d) + else: + n, d = s.split('/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_rat(n, d) + return p0 + + +def do_algebra( + g: gh.Graph, added: list[pr.Dependency], verbose: bool = False +) -> None: + for add in added: + g.add_algebra(add, None) + derives, eq4s = g.derive_algebra(level=None, verbose=verbose) + apply_derivations(g, derives) + apply_derivations(g, eq4s) + + +def apply_derivations( + g: gh.Graph, derives: dict[str, list[tuple[gm.Point, ...]]] +) -> list[pr.Dependency]: + applied = [] + all_derives = list(derives.items()) + for name, args in all_derives: + for arg in args: + applied += g.do_algebra(name, arg) + return applied diff --git a/dd_test.py b/dd_test.py new file mode 100644 index 0000000..6cb2c40 --- /dev/null +++ b/dd_test.py @@ -0,0 +1,79 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Unit tests for dd.""" +import unittest + +from absl.testing import absltest +import dd +import graph as gh +import problem as pr + + +MAX_LEVEL = 1000 + + +class DDTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) + + def test_imo_2022_p4_should_succeed(self): + p = pr.Problem.from_txt( + 'a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m =' + ' on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n' + ' g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b,' + ' on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a' + ' n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q' + ) + g, _ = gh.Graph.build_problem(p, DDTest.defs) + goal_args = g.names2nodes(p.goal.args) + + success = False + for level in range(MAX_LEVEL): + added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p) + if g.check(p.goal.name, goal_args): + success = True + break + if not added: # saturated + break + + self.assertTrue(success) + + def test_incenter_excenter_should_fail(self): + p = pr.Problem.from_txt( + 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' 
+ ' perp d c c e' + ) + g, _ = gh.Graph.build_problem(p, DDTest.defs) + goal_args = g.names2nodes(p.goal.args) + + success = False + for level in range(MAX_LEVEL): + added, _, _, _ = dd.bfs_one_level(g, DDTest.rules, level, p) + if g.check(p.goal.name, goal_args): + success = True + break + if not added: # saturated + break + + self.assertFalse(success) + + +if __name__ == '__main__': + absltest.main() diff --git a/ddar.py b/ddar.py new file mode 100644 index 0000000..8f910cb --- /dev/null +++ b/ddar.py @@ -0,0 +1,157 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Implements the combination DD+AR."""
+import time
+
+from absl import logging
+import dd
+import graph as gh
+import problem as pr
+from problem import Dependency  # pylint: disable=g-importing-member
+import trace_back
+
+
+def saturate_or_goal(
+    g: gh.Graph,
+    theorems: list[pr.Theorem],
+    level_times: list[float],
+    p: pr.Problem,
+    max_level: int = 100,
+    timeout: int = 600,
+) -> tuple[
+    list[dict[str, list[tuple[gh.Point, ...]]]],
+    list[dict[str, list[tuple[gh.Point, ...]]]],
+    list[int],
+    list[pr.Dependency],
+]:
+  """Run DD until saturation or goal found."""
+  derives = []
+  eq4s = []
+  branching = []
+  all_added = []
+
+  while len(level_times) < max_level:
+    level = len(level_times) + 1
+
+    t = time.time()
+    added, derv, eq4, n_branching = dd.bfs_one_level(
+        g, theorems, level, p, verbose=False, nm_check=True, timeout=timeout
+    )
+    all_added += added
+    branching.append(n_branching)
+
+    derives.append(derv)
+    eq4s.append(eq4)
+    level_time = time.time() - t
+
+    logging.info(f'Depth {level}/{max_level} time = {level_time}')  # pylint: disable=logging-fstring-interpolation
+    level_times.append(level_time)
+
+    if p.goal is not None:
+      goal_args = list(map(lambda x: g.get(x, lambda: int(x)), p.goal.args))
+      if g.check(p.goal.name, goal_args):  # found goal
+        break
+
+    if not added:  # saturated
+      break
+
+    if level_time > timeout:
+      break
+
+  return derives, eq4s, branching, all_added
+
+
+def solve(
+    g: gh.Graph,
+    theorems: list[pr.Theorem],
+    controller: pr.Problem,
+    max_level: int = 1000,
+    timeout: int = 600,
+) -> tuple[gh.Graph, list[float], str, list[int], list[pr.Dependency]]:
+  """Alternate between DD and AR until goal is found."""
+  status = 'saturated'
+  level_times = []
+
+  dervs, eq4 = g.derive_algebra(level=0, verbose=False)
+  derives = [dervs]
+  eq4s = [eq4]
+  branches = []
+  all_added = []
+
+  while len(level_times) < max_level:
+    dervs, eq4, 
next_branches, added = saturate_or_goal( + g, theorems, level_times, controller, max_level, timeout=timeout + ) + all_added += added + + derives += dervs + eq4s += eq4 + branches += next_branches + + # Now, it is either goal or saturated + if controller.goal is not None: + goal_args = g.names2points(controller.goal.args) + if g.check(controller.goal.name, goal_args): # found goal + status = 'solved' + break + + if not derives: # officially saturated. + break + + # Now we resort to algebra derivations. + added = [] + while derives and not added: + added += dd.apply_derivations(g, derives.pop(0)) + + if added: + continue + + # Final help from AR. + while eq4s and not added: + added += dd.apply_derivations(g, eq4s.pop(0)) + + all_added += added + + if not added: # Nothing left. saturated. + break + + return g, level_times, status, branches, all_added + + +def get_proof_steps( + g: gh.Graph, goal: pr.Clause, merge_trivials: bool = False +) -> tuple[ + list[pr.Dependency], + list[pr.Dependency], + list[tuple[list[pr.Dependency], list[pr.Dependency]]], + dict[tuple[str, ...], int], +]: + """Extract proof steps from the built DAG.""" + goal_args = g.names2nodes(goal.args) + query = Dependency(goal.name, goal_args, None, None) + + setup, aux, log, setup_points = trace_back.get_logs( + query, g, merge_trivials=merge_trivials + ) + + refs = {} + setup = trace_back.point_log(setup, refs, set()) + aux = trace_back.point_log(aux, refs, setup_points) + + setup = [(prems, [tuple(p)]) for p, prems in setup] + aux = [(prems, [tuple(p)]) for p, prems in aux] + + return setup, aux, log, refs diff --git a/ddar_test.py b/ddar_test.py new file mode 100644 index 0000000..7f68a4f --- /dev/null +++ b/ddar_test.py @@ -0,0 +1,65 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for ddar.py.""" +import unittest + +from absl.testing import absltest +import ddar +import graph as gh +import problem as pr + + +class DDARTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) + + def test_orthocenter_should_fail(self): + txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c' # pylint: disable=line-too-long + p = pr.Problem.from_txt(txt) + g, _ = gh.Graph.build_problem(p, DDARTest.defs) + + ddar.solve(g, DDARTest.rules, p, max_level=1000) + goal_args = g.names2nodes(p.goal.args) + self.assertFalse(g.check(p.goal.name, goal_args)) + + def test_orthocenter_aux_should_succeed(self): + txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c' # pylint: disable=line-too-long + p = pr.Problem.from_txt(txt) + g, _ = gh.Graph.build_problem(p, DDARTest.defs) + + ddar.solve(g, DDARTest.rules, p, max_level=1000) + goal_args = g.names2nodes(p.goal.args) + self.assertTrue(g.check(p.goal.name, goal_args)) + + def test_incenter_excenter_should_succeed(self): + # Note that this same problem should fail in dd_test.py + p = pr.Problem.from_txt( + 'a b c = triangle a b c; d = incenter d a b c; e = excenter e a b c ?' 
+ ' perp d c c e' + ) # pylint: disable=line-too-long + g, _ = gh.Graph.build_problem(p, DDARTest.defs) + + ddar.solve(g, DDARTest.rules, p, max_level=1000) + goal_args = g.names2nodes(p.goal.args) + self.assertTrue(g.check(p.goal.name, goal_args)) + + +if __name__ == '__main__': + absltest.main() diff --git a/decoder_stack.py b/decoder_stack.py new file mode 100644 index 0000000..e2098b6 --- /dev/null +++ b/decoder_stack.py @@ -0,0 +1,55 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The decoder stack in inference mode.""" + +from typing import Any, Tuple + +import gin +from transformer import decoder_stack +import transformer_layer as tl + + +struct = decoder_stack.struct +nn_components = decoder_stack.nn_components +position = decoder_stack.position +jnp = decoder_stack.jnp +attention = decoder_stack.attention + +DStackWindowState = decoder_stack.DStackWindowState + +Array = Any + +TransformerTaskConfig = decoder_stack.TransformerTaskConfig + +DStackDecoderState = Tuple[tl.DecoderState, ...] 
+ + +@gin.configurable +class DecoderStackGenerate(decoder_stack.DecoderStack): + """Stack of transformer decoder layers.""" + + layer_factory = tl.TransformerLayerGenerate + + def init_decoder_state_vanilla( + self, sequence_length: int, start_of_sequence: Array + ) -> DStackDecoderState: + """Return initial state for autoregressive generation.""" + return tuple( + [ + layer.init_decoder_state_vanilla(sequence_length, start_of_sequence) + for layer in self.transformer_layers + ] + ) diff --git a/defs.txt b/defs.txt new file mode 100644 index 0000000..ed87ef1 --- /dev/null +++ b/defs.txt @@ -0,0 +1,407 @@ +angle_bisector x a b c +x : a b c x +a b c = ncoll a b c +x : eqangle b a b x b x b c +bisect a b c + +angle_mirror x a b c +x : a b c x +a b c = ncoll a b c +x : eqangle b a b c b c b x +amirror a b c + +circle x a b c +x : a b c +a b c = ncoll a b c +x : cong x a x b, cong x b x c +bline a b, bline a c + +circumcenter x a b c +x : a b c +a b c = ncoll a b c +x : cong x a x b, cong x b x c +bline a b, bline a c + +eq_quadrangle a b c d +d : a b c d + = +a : ; b : ; c : ; d : cong d a b c +eq_quadrangle + +eq_trapezoid a b c d +d : a b c + = +a : ; b : ; c : ; d : para d c a b, cong d a b c +eq_trapezoid + +eq_triangle x b c +x : b c +b c = diff b c +x : cong x b b c, cong b c c x; eqangle b x b c c b c x, eqangle x c x b b x b c +circle b b c, circle c b c + +eqangle2 x a b c +x : a b c x +a b c = ncoll a b c +x : eqangle a b a x c x c b +eqangle2 a b c + +eqdia_quadrangle a b c d +d : a b c d + = +a : ; b : ; c : ; d : cong d b a c +eqdia_quadrangle + +eqdistance x a b c +x : a b c x +a b c = diff b c +x : cong x a b c +circle a b c + +foot x a b c +x : a b c +a b c = ncoll a b c +x : perp x a b c, coll x b c +tline a b c, line b c + +free a +a : a + = +a : +free + +incenter x a b c +x : a b c +a b c = ncoll a b c +x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a +bisect a b c, bisect b c a + +incenter2 x y z i a b c +i : a b c, x : 
i b c, y : i c a, z : i a b +a b c = ncoll a b c +i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z +incenter2 a b c + +excenter x a b c +x : a b c +a b c = ncoll a b c +x : eqangle a b a x a x a c, eqangle c a c x c x c b; eqangle b c b x b x b a +bisect b a c, exbisect b c a + +excenter2 x y z i a b c +i : a b c, x : i b c, y : i c a, z : i a b +a b c = ncoll a b c +i : eqangle a b a i a i a c, eqangle c a c i c i c b; eqangle b c b i b i b a; x : coll x b c, perp i x b c; y : coll y c a, perp i y c a; z : coll z a b, perp i z a b; cong i x i y, cong i y i z +excenter2 a b c + +centroid x y z i a b c +x : b c, y : c a, z : a b, i : a x b y +a b c = ncoll a b c +x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : coll a x i, coll b y i; coll c z i +centroid a b c + +ninepoints x y z i a b c +x : b c, y : c a, z : a b, i : x y z +a b c = ncoll a b c +x : coll x b c, cong x b x c; y : coll y c a, cong y c y a; z : coll z a b, cong z a z b; i : cong i x i y, cong i y i z +ninepoints a b c + +intersection_cc x o w a +x : o w a +o w a = ncoll o w a +x : cong o a o x, cong w a w x +circle o o a, circle w w a + +intersection_lc x a o b +x : a o b +a o b = diff a b, diff o b, nperp b o b a +x : coll x a b, cong o b o x +line b a, circle o o b + +intersection_ll x a b c d +x : a b c d +a b c d = npara a b c d, ncoll a b c d +x : coll x a b, coll x c d +line a b, line c d + +intersection_lp x a b c m n +x : a b c m n +a b c m n = npara m n a b, ncoll a b c, ncoll c m n +x : coll x a b, para c x m n +line a b, pline c m n + +intersection_lt x a b c d e +x : a b c d e +a b c d e = ncoll a b c, nperp a b d e +x : coll x a b, perp x c d e +line a b, tline c d e + +intersection_pp x a b c d e f +x : a b c d e f +a b c d e f = diff a d, npara b c e f +x : para x a b c, para x d e f +pline a b c, 
pline d e f + +intersection_tt x a b c d e f +x : a b c d e f +a b c d e f = diff a d, npara b c e f +x : perp x a b c, perp x d e f +tline a b c, tline d e f + +iso_triangle a b c +c : a b c + = +a : ; b : ; c : eqangle b a b c c b c a, cong a b a c +isos + +lc_tangent x a o +x : x a o +a o = diff a o +x : perp a x a o +tline a a o + +midpoint x a b +x : a b +a b = diff a b +x : coll x a b, cong x a x b +midp a b + +mirror x a b +x : a b +a b = diff a b +x : coll x a b, cong b a b x +pmirror a b + +nsquare x a b +x : a b +a b = diff a b +x : cong x a a b, perp x a a b +rotaten90 a b + +on_aline x a b c d e +x : x a b c d e +a b c d e = ncoll c d e +x : eqangle a x a b d c d e +aline e d c b a + +on_aline2 x a b c d e +x : x a b c d e +a b c d e = ncoll c d e +x : eqangle x a x b d c d e +aline2 e d c b a + +on_bline x a b +x : x a b +a b = diff a b +x : cong x a x b, eqangle a x a b b a b x +bline a b + +on_circle x o a +x : x o a +o a = diff o a +x : cong o x o a +circle o o a + +on_line x a b +x : x a b +a b = diff a b +x : coll x a b +line a b + +on_pline x a b c +x : x a b c +a b c = diff b c, ncoll a b c +x : para x a b c +pline a b c + +on_tline x a b c +x : x a b c +a b c = diff b c +x : perp x a b c +tline a b c + +orthocenter x a b c +x : a b c +a b c = ncoll a b c +x : perp x a b c, perp x b c a; perp x c a b +tline a b c, tline b c a + +parallelogram a b c x +x : a b c +a b c = ncoll a b c +x : para a b c x, para a x b c; cong a b c x, cong a x b c +pline a b c, pline c a b + +pentagon a b c d e + + = +a : ; b : ; c : ; d : ; e : +pentagon + +psquare x a b +x : a b +a b = diff a b +x : cong x a a b, perp x a a b +rotatep90 a b + +quadrangle a b c d + + = +a : ; b : ; c : ; d : +quadrangle + +r_trapezoid a b c d +d : a b c + = +a : ; b : ; c : ; d : para a b c d, perp a b a d +r_trapezoid + +r_triangle a b c +c : a b c + = +a : ; b : ; c : perp a b a c +r_triangle + +rectangle a b c d +c : a b c , d : a b c + = +a : ; b : ; c : perp a b b c ; d : para a 
b c d, para a d b c; perp a b a d, cong a b c d, cong a d b c, cong a c b d +rectangle + +reflect x a b c +x : a b c +a b c = diff b c, ncoll a b c +x : cong b a b x, cong c a c x; perp b c a x +reflect a b c + +risos a b c +c : a b + = +a : ; b : ; c : perp a b a c, cong a b a c; eqangle b a b c c b c a +risos + +s_angle a b x y +x : a b x +a b = diff a b +x : s_angle a b x y +s_angle a b y + +segment a b + + = +a : ; b : +segment + +shift x b c d +x : b c d +b c d = diff d b +x : cong x b c d, cong x c b d +shift d c b + +square a b x y +x : a b, y : a b x +a b = diff a b +x : perp a b b x, cong a b b x; y : para a b x y, para a y b x; perp a y y x, cong b x x y, cong x y y a, perp a x b y, cong a x b y +square a b + +isquare a b c d +c : a b , d : a b c + = +a : ; b : ; c : perp a b b c, cong a b b c; d : para a b c d, para a d b c; perp a d d c, cong b c c d, cong c d d a, perp a c b d, cong a c b d +isquare + +trapezoid a b c d +d : a b c d + = +a : ; b : ; c : ; d : para a b c d +trapezoid + +triangle a b c + + = +a : ; b : ; c : +triangle + +triangle12 a b c +c : a b c + = +a : ; b : ; c : rconst a b a c 1 2 +triangle12 + +2l1c x y z i a b c o +x : a b c o y z i, y : a b c o x z i, z : a b c o x y i, i : a b c o x y z +a b c o = cong o a o b, ncoll a b c +x y z i : coll x a c, coll y b c, cong o a o z, coll i o z, cong i x i y, cong i y i z, perp i x a c, perp i y b c +2l1c a b c o + +e5128 x y a b c d +x : a b c d y, y : a b c d x +a b c d = cong c b c d, perp b c b a +x y : cong c b c x, coll y a b, coll x y d, eqangle a b a d x a x y +e5128 a b c d + +3peq x y z a b c +z : b c z , x : a b c z y, y : a b c z x +a b c = ncoll a b c +z : coll z b c ; x y : coll x a b, coll y a c, coll x y z, cong z x z y +3peq a b c + +trisect x y a b c +x : a b c y, y : a b c x +a b c = ncoll a b c +x y : coll x a c, coll y a c, eqangle b a b x b x b y, eqangle b x b y b y b c +trisect a b c + +trisegment x y a b +x : a b y, y : a b x +a b = diff a b +x y : coll x a b, coll 
y a b, cong x a x y, cong y x y b +trisegment a b + +on_dia x a b +x : x a b +a b = diff a b +x : perp x a x b +dia a b + +ieq_triangle a b c +c : a b + = +a : ; b : ; c : cong a b b c, cong b c c a; eqangle a b a c c a c b, eqangle c a c b b c b a +ieq_triangle + +on_opline x a b +x : x a b +a b = diff a b +x : coll x a b +on_opline a b + +cc_tangent0 x y o a w b +x : o a w b y, y : o a w b x +o a w b = diff o a, diff w b, diff o w +x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x +cc_tangent0 o a w b + +cc_tangent x y z i o a w b +x : o a w b y, y : o a w b x, z : o a w b i, i : o a w b z +o a w b = diff o a, diff w b, diff o w +x y : cong o x o a, cong w y w b, perp x o x y, perp y w y x; z i : cong o z o a, cong w i w b, perp z o z i, perp i w i z +cc_tangent o a w b + +eqangle3 x a b d e f +x : x a b d e f +a b d e f = ncoll d e f, diff a b, diff d e, diff e f +x : eqangle x a x b d e d f +eqangle3 a b d e f + +tangent x y a o b +x y : o a b +a o b = diff o a, diff o b, diff a b +x : cong o x o b, perp a x o x; y : cong o y o b, perp a y o y +tangent a o b + +on_circum x a b c +x : a b c +a b c = ncoll a b c +x : cyclic a b c x +cyclic a b c diff --git a/download.sh b/download.sh new file mode 100644 index 0000000..a73e35e --- /dev/null +++ b/download.sh @@ -0,0 +1,17 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +gdown --folder https://bit.ly/g0-ckpt-vocab +export DATA=g0_ckpt_vocab diff --git a/examples.txt b/examples.txt new file mode 100644 index 0000000..81c9b71 --- /dev/null +++ b/examples.txt @@ -0,0 +1,8 @@ +orthocenter +a b c = triangle; h = on_tline b a c, on_tline c a b ? perp a h b c +orthocenter_aux +a b c = triangle; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c +incenter_excenter +a b c = triangle a b c; d1 d2 d3 d = incenter2 a b c; e1 e2 e3 e = excenter2 a b c ? perp d c c e +euler +a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? coll h g o diff --git a/fig1.svg b/fig1.svg new file mode 100644 index 0000000..72e6666 --- /dev/null +++ b/fig1.svg @@ -0,0 +1 @@ + diff --git a/geometry.py b/geometry.py new file mode 100644 index 0000000..ba9463f --- /dev/null +++ b/geometry.py @@ -0,0 +1,578 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implements geometric objects used in the graph representation.""" +from __future__ import annotations +from collections import defaultdict # pylint: disable=g-importing-member +from typing import Any, Type + +# pylint: disable=protected-access + + +class Node: + r"""Node in the proof state graph. + + Can be Point, Line, Circle, etc. + + Each node maintains a merge history to + other nodes if they are (found out to be) equivalent + + a -> b - + \ + c -> d -> e -> f -> g + + d.merged_to = e + d.rep = g + d.merged_from = {a, b, c, d} + d.equivs = {a, b, c, d, e, f, g} + """ + + def __init__(self, name: str = '', graph: Any = None): + self.name = name or str(self) + self.graph = graph + + self.edge_graph = {} + # Edge graph: what other nodes is connected to this node. + # edge graph = { + # other1: {self1: deps, self2: deps}, + # other2: {self2: deps, self3: deps} + # } + + self.merge_graph = {} + # Merge graph: history of merges with other nodes. + # merge_graph = {self1: {self2: deps1, self3: deps2}} + + self.rep_by = None # represented by. + self.members = {self} + + self._val = None + self._obj = None + + self.deps = [] + + # numerical representation. + self.num = None + self.change = set() # what other nodes' num rely on this node? 
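+
+  # Illustrative sketch (an assumption for orientation, not part of the
+  # original class): merging the Length values of two congruent segments
+  # makes one node the representative of the other, mirroring the diagram
+  # in the class docstring above:
+  #
+  #   la, lb = Length('l(a)'), Length('l(b)')
+  #   la.merge([lb], deps='fact')
+  #   lb.rep() is la   # True: la now represents lb
+  #   lb.equivs()      # {la, lb}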
+ + def set_rep(self, node: Node) -> None: + if node == self: + return + self.rep_by = node + node.merge_edge_graph(self.edge_graph) + node.members.update(self.members) + + def rep(self) -> Node: + x = self + while x.rep_by: + x = x.rep_by + return x + + def why_rep(self) -> list[Any]: + return self.why_equal([self.rep()], None) + + def rep_and_why(self) -> tuple[Node, list[Any]]: + rep = self.rep() + return rep, self.why_equal([rep], None) + + def neighbors( + self, oftype: Type[Node], return_set: bool = False, do_rep: bool = True + ) -> list[Node]: + """Neighbors of this node in the proof state graph.""" + if do_rep: + rep = self.rep() + else: + rep = self + result = set() + + for n in rep.edge_graph: + if oftype is None or oftype and isinstance(n, oftype): + if do_rep: + result.add(n.rep()) + else: + result.add(n) + + if return_set: + return result + return list(result) + + def merge_edge_graph( + self, new_edge_graph: dict[Node, dict[Node, list[Node]]] + ) -> None: + for x, xdict in new_edge_graph.items(): + if x in self.edge_graph: + self.edge_graph[x].update(dict(xdict)) + else: + self.edge_graph[x] = dict(xdict) + + def merge(self, nodes: list[Node], deps: list[Any]) -> None: + for node in nodes: + self.merge_one(node, deps) + + def merge_one(self, node: Node, deps: list[Any]) -> None: + node.rep().set_rep(self.rep()) + + if node in self.merge_graph: + return + + self.merge_graph[node] = deps + node.merge_graph[self] = deps + + def is_val(self, node: Node) -> bool: + return ( + isinstance(self, Line) + and isinstance(node, Direction) + or isinstance(self, Segment) + and isinstance(node, Length) + or isinstance(self, Angle) + and isinstance(node, Measure) + or isinstance(self, Ratio) + and isinstance(node, Value) + ) + + def set_val(self, node: Node) -> None: + self._val = node + + def set_obj(self, node: Node) -> None: + self._obj = node + + @property + def val(self) -> Node: + if self._val is None: + return None + return self._val.rep() + + @property + def 
obj(self) -> Node: + if self._obj is None: + return None + return self._obj.rep() + + def equivs(self) -> set[Node]: + return self.rep().members + + def connect_to(self, node: Node, deps: list[Any] = None) -> None: + rep = self.rep() + + if node in rep.edge_graph: + rep.edge_graph[node].update({self: deps}) + else: + rep.edge_graph[node] = {self: deps} + + if self.is_val(node): + self.set_val(node) + node.set_obj(self) + + def equivs_upto(self, level: int) -> dict[Node, Node]: + """What are the equivalent nodes up to a certain level.""" + parent = {self: None} + visited = set() + queue = [self] + i = 0 + + while i < len(queue): + current = queue[i] + i += 1 + visited.add(current) + + for neighbor in current.merge_graph: + if ( + level is not None + and current.merge_graph[neighbor].level is not None + and current.merge_graph[neighbor].level >= level + ): + continue + if neighbor not in visited: + queue.append(neighbor) + parent[neighbor] = current + + return parent + + def why_equal(self, others: list[Node], level: int) -> list[Any]: + """BFS why this node is equal to other nodes.""" + others = set(others) + found = 0 + + parent = {} + queue = [self] + i = 0 + + while i < len(queue): + current = queue[i] + if current in others: + found += 1 + if found == len(others): + break + + i += 1 + + for neighbor in current.merge_graph: + if ( + level is not None + and current.merge_graph[neighbor].level is not None + and current.merge_graph[neighbor].level >= level + ): + continue + if neighbor not in parent: + queue.append(neighbor) + parent[neighbor] = current + + return bfs_backtrack(self, others, parent) + + def why_equal_groups( + self, groups: list[list[Node]], level: int + ) -> tuple[list[Any], list[Node]]: + """BFS for why self is equal to at least one member of each group.""" + others = [None for _ in groups] + found = 0 + + parent = {} + queue = [self] + i = 0 + + while i < len(queue): + current = queue[i] + + for j, grp in enumerate(groups): + if others[j] is None 
and current in grp:
+          others[j] = current
+          found += 1
+
+      if found == len(others):
+        break
+
+      i += 1
+
+      for neighbor in current.merge_graph:
+        if (
+            level is not None
+            and current.merge_graph[neighbor].level is not None
+            and current.merge_graph[neighbor].level >= level
+        ):
+          continue
+        if neighbor not in parent:
+          queue.append(neighbor)
+          parent[neighbor] = current
+
+    return bfs_backtrack(self, others, parent), others
+
+  def why_val(self, level: int) -> list[Any]:
+    return self._val.why_equal([self.val], level)
+
+  def why_connect(self, node: Node, level: int = None) -> list[Any]:
+    rep = self.rep()
+    equivs = list(rep.edge_graph[node].keys())
+    if not equivs:
+      return None
+    equiv = equivs[0]
+    dep = rep.edge_graph[node][equiv]
+    # why_equal expects a list of target nodes.
+    return [dep] + self.why_equal([equiv], level)
+
+
+def why_connect(*pairs: tuple[Node, Node]) -> list[Any]:
+  result = []
+  for node1, node2 in pairs:
+    result += node1.why_connect(node2)
+  return result
+
+
+def is_equiv(x: Node, y: Node, level: int = None) -> bool:
+  level = level or float('inf')
+  return x.why_equal([y], level) is not None
+
+
+def is_equal(x: Node, y: Node, level: int = None) -> bool:
+  if x == y:
+    return True
+  if x._val is None or y._val is None:
+    return False
+  if x.val != y.val:
+    return False
+  return is_equiv(x._val, y._val, level)
+
+
+def bfs_backtrack(
+    root: Node, leafs: list[Node], parent: dict[Node, Node]
+) -> list[Any]:
+  """Return the path given BFS trace of parent nodes."""
+  backtracked = {root}  # no need to backtrack further when touching this set.
+ deps = [] + for node in leafs: + if node is None: + return None + if node in backtracked: + continue + if node not in parent: + return None + while node not in backtracked: + backtracked.add(node) + deps.append(node.merge_graph[parent[node]]) + node = parent[node] + + return deps + + +class Point(Node): + pass + + +class Line(Node): + """Node of type Line.""" + + def new_val(self) -> Direction: + return Direction() + + def why_coll(self, points: list[Point], level: int = None) -> list[Any]: + """Why points are connected to self.""" + level = level or float('inf') + + groups = [] + for p in points: + group = [ + l + for l, d in self.edge_graph[p].items() + if d is None or d.level < level + ] + if not group: + return None + groups.append(group) + + min_deps = None + for line in groups[0]: + deps, others = line.why_equal_groups(groups[1:], level) + if deps is None: + continue + for p, o in zip(points, [line] + others): + deps.append(self.edge_graph[p][o]) + if min_deps is None or len(deps) < len(min_deps): + min_deps = deps + + if min_deps is None: + return None + return [d for d in min_deps if d is not None] + + +class Segment(Node): + + def new_val(self) -> Length: + return Length() + + +class Circle(Node): + """Node of type Circle.""" + + def why_cyclic(self, points: list[Point], level: int = None) -> list[Any]: + """Why points are connected to self.""" + level = level or float('inf') + + groups = [] + for p in points: + group = [ + c + for c, d in self.edge_graph[p].items() + if d is None or d.level < level + ] + if not group: + return None + groups.append(group) + + min_deps = None + for circle in groups[0]: + deps, others = circle.why_equal_groups(groups[1:], level) + if deps is None: + continue + for p, o in zip(points, [circle] + others): + deps.append(self.edge_graph[p][o]) + + if min_deps is None or len(deps) < len(min_deps): + min_deps = deps + + if min_deps is None: + return None + return [d for d in min_deps if d is not None] + + +def why_equal(x: Node, 
y: Node, level: int = None) -> list[Any]: + if x == y: + return [] + if not x._val or not y._val: + return None + if x._val == y._val: + return [] + return x._val.why_equal([y._val], level) + + +class Direction(Node): + pass + + +def get_lines_thru_all(*points: list[Point]) -> list[Line]: + line2count = defaultdict(lambda: 0) + points = set(points) + for p in points: + for l in p.neighbors(Line): + line2count[l] += 1 + return [l for l, count in line2count.items() if count == len(points)] + + +def line_of_and_why( + points: list[Point], level: int = None +) -> tuple[Line, list[Any]]: + """Why points are collinear.""" + for l0 in get_lines_thru_all(*points): + for l in l0.equivs(): + if all([p in l.edge_graph for p in points]): + x, y = l.points + colls = list({x, y} | set(points)) + # if len(colls) < 3: + # return l, [] + why = l.why_coll(colls, level) + if why is not None: + return l, why + + return None, None + + +def get_circles_thru_all(*points: list[Point]) -> list[Circle]: + circle2count = defaultdict(lambda: 0) + points = set(points) + for p in points: + for c in p.neighbors(Circle): + circle2count[c] += 1 + return [c for c, count in circle2count.items() if count == len(points)] + + +def circle_of_and_why( + points: list[Point], level: int = None +) -> tuple[Circle, list[Any]]: + """Why points are concyclic.""" + for c0 in get_circles_thru_all(*points): + for c in c0.equivs(): + if all([p in c.edge_graph for p in points]): + cycls = list(set(points)) + why = c.why_cyclic(cycls, level) + if why is not None: + return c, why + + return None, None + + +def name_map(struct: Any) -> Any: + if isinstance(struct, list): + return [name_map(x) for x in struct] + elif isinstance(struct, tuple): + return tuple([name_map(x) for x in struct]) + elif isinstance(struct, set): + return set([name_map(x) for x in struct]) + elif isinstance(struct, dict): + return {name_map(x): name_map(y) for x, y in struct.items()} + else: + return getattr(struct, 'name', '') + + +class 
Angle(Node): + """Node of type Angle.""" + + def new_val(self) -> Measure: + return Measure() + + def set_directions(self, d1: Direction, d2: Direction) -> None: + self._d = d1, d2 + + @property + def directions(self) -> tuple[Direction, Direction]: + d1, d2 = self._d + if d1 is None or d2 is None: + return d1, d2 + return d1.rep(), d2.rep() + + +class Measure(Node): + pass + + +class Length(Node): + pass + + +class Ratio(Node): + """Node of type Ratio.""" + + def new_val(self) -> Value: + return Value() + + def set_lengths(self, l1: Length, l2: Length) -> None: + self._l = l1, l2 + + @property + def lengths(self) -> tuple[Length, Length]: + l1, l2 = self._l + if l1 is None or l2 is None: + return l1, l2 + return l1.rep(), l2.rep() + + +class Value(Node): + pass + + +def all_angles( + d1: Direction, d2: Direction, level: int = None +) -> tuple[Angle, list[Direction], list[Direction]]: + level = level or float('inf') + d1s = d1.equivs_upto(level) + d2s = d2.equivs_upto(level) + + for ang in d1.rep().neighbors(Angle): + d1_, d2_ = ang._d + if d1_ in d1s and d2_ in d2s: + yield ang, d1s, d2s + + +def all_ratios( + d1, d2, level=None +) -> tuple[Angle, list[Direction], list[Direction]]: + level = level or float('inf') + d1s = d1.equivs_upto(level) + d2s = d2.equivs_upto(level) + + for ang in d1.rep().neighbors(Ratio): + d1_, d2_ = ang._l + if d1_ in d1s and d2_ in d2s: + yield ang, d1s, d2s + + +RANKING = { + Point: 0, + Line: 1, + Segment: 2, + Circle: 3, + Direction: 4, + Length: 5, + Angle: 6, + Ratio: 7, + Measure: 8, + Value: 9, +} + + +def val_type(x: Node) -> Type[Node]: + if isinstance(x, Line): + return Direction + if isinstance(x, Segment): + return Length + if isinstance(x, Angle): + return Measure + if isinstance(x, Ratio): + return Value diff --git a/geometry_150M_generate.gin b/geometry_150M_generate.gin new file mode 100644 index 0000000..5a9bce4 --- /dev/null +++ b/geometry_150M_generate.gin @@ -0,0 +1,47 @@ +NUM_EMBEDDINGS = 1024 + +# Number of 
parameters = 152M +NUM_LAYERS = 12 +EMBED_DIM = 1024 +NUM_HEADS = 8 +HEAD_DIM = 128 +MLP_DIM = 4096 + + +transformer_layer.TransformerLayerGenerate: + num_heads = %NUM_HEADS + head_size = %HEAD_DIM + window_length = 1024 + use_long_xl_architecture = False + max_unrolled_windows = -1 # Always unroll. + relative_position_type = "t5" # Can be "fourier", "t5", or None. + use_causal_mask = True + attn_dropout_rate = %ATTN_DROPOUT_RATE # Attention matrix dropout. + memory_num_neighbors = 0 + dtype = %DTYPE + +decoder_stack.DecoderStackGenerate: + num_layers = %NUM_LAYERS + embedding_size = %EMBED_DIM + embedding_stddev = 1.0 + layer_factory = @transformer_layer.TransformerLayerGenerate + dstack_window_length = 0 + use_absolute_positions = False + use_final_layernorm = True # Final layernorm before token lookup. + final_dropout_rate = %DROPOUT_RATE # Dropout before token lookup. + final_mlp_factory = None # Final MLP to predict target tokens. + recurrent_layer_indices = () + memory_factory = None # e.g. @memory_factory.memory_on_tpu_factory + memory_layer_indices = () + dtype = %DTYPE + + +models.DecoderOnlyLanguageModelGenerate: + num_heads = %NUM_HEADS + head_size = %HEAD_DIM + task_config = @decoder_stack.TransformerTaskConfig() + decoder_factory = @decoder_stack.DecoderStackGenerate + + +training_loop.Trainer: + model_definition = @models.DecoderOnlyLanguageModelGenerate diff --git a/geometry_test.py b/geometry_test.py new file mode 100644 index 0000000..d711996 --- /dev/null +++ b/geometry_test.py @@ -0,0 +1,80 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for geometry.py.""" +import unittest + +from absl.testing import absltest +import geometry as gm + + +class GeometryTest(unittest.TestCase): + + def _setup_equality_example(self): + # Create 4 nodes a, b, c, d + # and their lengths + a = gm.Segment('a') + la = gm.Length('l(a)') + a.connect_to(la) + la.connect_to(a) + + b = gm.Segment('b') + lb = gm.Length('l(b)') + b.connect_to(lb) + lb.connect_to(b) + + c = gm.Segment('c') + lc = gm.Length('l(c)') + c.connect_to(lc) + lc.connect_to(c) + + d = gm.Segment('d') + ld = gm.Length('l(d)') + d.connect_to(ld) + ld.connect_to(d) + + # Now let a=b, b=c, a=c, c=d + la.merge([lb], 'fact1') + lb.merge([lc], 'fact2') + la.merge([lc], 'fact3') + lc.merge([ld], 'fact4') + return a, b, c, d, la, lb, lc, ld + + def test_merged_node_representative(self): + _, _, _, _, la, lb, lc, ld = self._setup_equality_example() + + # all nodes are now represented by la. 
+ self.assertEqual(la.rep(), la) + self.assertEqual(lb.rep(), la) + self.assertEqual(lc.rep(), la) + self.assertEqual(ld.rep(), la) + + def test_merged_node_equivalence(self): + _, _, _, _, la, lb, lc, ld = self._setup_equality_example() + # all la, lb, lc, ld are equivalent + self.assertCountEqual(la.equivs(), [la, lb, lc, ld]) + self.assertCountEqual(lb.equivs(), [la, lb, lc, ld]) + self.assertCountEqual(lc.equivs(), [la, lb, lc, ld]) + self.assertCountEqual(ld.equivs(), [la, lb, lc, ld]) + + def test_bfs_for_equality_transitivity(self): + a, _, _, d, _, _, _, _ = self._setup_equality_example() + + # check that a==d because fact3 & fact4, not fact1 & fact2 + self.assertCountEqual(gm.why_equal(a, d), ['fact3', 'fact4']) + + +if __name__ == '__main__': + absltest.main() diff --git a/graph.py b/graph.py new file mode 100644 index 0000000..ddadb4a --- /dev/null +++ b/graph.py @@ -0,0 +1,3057 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implements the graph representation of the proof state.""" + +# pylint: disable=g-multiple-import +from __future__ import annotations + +from collections import defaultdict # pylint: disable=g-importing-member +from typing import Callable, Generator, Optional, Type, Union + +from absl import logging +import ar +import geometry as gm +from geometry import Angle, Direction, Length, Ratio +from geometry import Circle, Line, Point, Segment +from geometry import Measure, Value +import graph_utils as utils +import numericals as nm +import problem +from problem import Dependency, EmptyDependency + + +np = nm.np + + +FREE = [ + 'free', + 'segment', + 'r_triangle', + 'risos', + 'triangle', + 'triangle12', + 'ieq_triangle', + 'eq_quadrangle', + 'eq_trapezoid', + 'eqdia_quadrangle', + 'quadrangle', + 'r_trapezoid', + 'rectangle', + 'isquare', + 'trapezoid', + 'pentagon', + 'iso_triangle', +] + +INTERSECT = [ + 'angle_bisector', + 'angle_mirror', + 'eqdistance', + 'lc_tangent', + 'on_aline', + 'on_bline', + 'on_circle', + 'on_line', + 'on_pline', + 'on_tline', + 'on_dia', + 's_angle', + 'on_opline', + 'eqangle3', +] + + +# pylint: disable=protected-access +# pylint: disable=unused-argument + + +class DepCheckFailError(Exception): + pass + + +class PointTooCloseError(Exception): + pass + + +class PointTooFarError(Exception): + pass + + +class Graph: + """Graph data structure representing proof state.""" + + def __init__(self): + self.type2nodes = { + Point: [], + Line: [], + Segment: [], + Circle: [], + Direction: [], + Length: [], + Angle: [], + Ratio: [], + Measure: [], + Value: [], + } + self._name2point = {} + self._name2node = {} + + self.rconst = {} # contains all constant ratios + self.aconst = {} # contains all constant angles. 
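+    # Illustrative examples (an assumption, for orientation): keys are
+    # simplified (numerator, denominator) pairs, e.g. after
+    # get_or_create_const_ang(1, 2), self.aconst[(1, 2)] is the Angle named
+    # '1pi/2' (a right angle), and after get_or_create_const_rat(1, 2),
+    # self.rconst holds both the '1/2' and '2/1' Ratio nodes.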
+
+    self.halfpi, _ = self.get_or_create_const_ang(1, 2)
+    self.vhalfpi = self.halfpi.val
+
+    self.atable = ar.AngleTable()
+    self.dtable = ar.DistanceTable()
+    self.rtable = ar.RatioTable()
+
+    # for quick access to deps.
+    self.cache = {}
+
+    self._pair2line = {}
+    self._triplet2circle = {}
+
+  def copy(self) -> Graph:
+    """Make a copy of self."""
+    p, definitions = self.build_def
+
+    p = p.copy()
+    for clause in p.clauses:
+      clause.nums = []
+      for pname in clause.points:
+        clause.nums.append(self._name2node[pname].num)
+
+    g, _ = Graph.build_problem(p, definitions, verbose=False, init_copy=False)
+
+    g.build_clauses = list(getattr(self, 'build_clauses', []))
+    return g
+
+  def _create_const_ang(self, n: int, d: int) -> None:
+    n, d = ar.simplify(n, d)
+    ang = self.aconst[(n, d)] = self.new_node(Angle, f'{n}pi/{d}')
+    ang.set_directions(None, None)
+    self.connect_val(ang, deps=None)
+
+  def _create_const_rat(self, n: int, d: int) -> None:
+    n, d = ar.simplify(n, d)
+    rat = self.rconst[(n, d)] = self.new_node(Ratio, f'{n}/{d}')
+    rat.set_lengths(None, None)
+    self.connect_val(rat, deps=None)
+
+  def get_or_create_const_ang(self, n: int, d: int) -> tuple[Angle, Angle]:
+    n, d = ar.simplify(n, d)
+    if (n, d) not in self.aconst:
+      self._create_const_ang(n, d)
+    ang1 = self.aconst[(n, d)]
+
+    n, d = ar.simplify(d - n, d)
+    if (n, d) not in self.aconst:
+      self._create_const_ang(n, d)
+    ang2 = self.aconst[(n, d)]
+    return ang1, ang2
+
+  def get_or_create_const_rat(self, n: int, d: int) -> tuple[Ratio, Ratio]:
+    n, d = ar.simplify(n, d)
+    if (n, d) not in self.rconst:
+      self._create_const_rat(n, d)
+    rat1 = self.rconst[(n, d)]
+
+    if (d, n) not in self.rconst:
+      self._create_const_rat(d, n)  # pylint: disable=arguments-out-of-order
+    rat2 = self.rconst[(d, n)]
+    return rat1, rat2
+
+  def add_algebra(self, dep: Dependency, level: int) -> None:
+    """Add new algebraic predicates."""
+    _ = level
+    if dep.name not in [
+        'para',
+        'perp',
+        'eqangle',
+        'eqratio',
+        'aconst',
+        'rconst',
+        'cong',
+    ]:
+      return
+ + name, args = dep.name, dep.args + + if name == 'para': + ab, cd = dep.algebra + self.atable.add_para(ab, cd, dep) + + if name == 'perp': + ab, cd = dep.algebra + self.atable.add_const_angle(ab, cd, 90, dep) + + if name == 'eqangle': + ab, cd, mn, pq = dep.algebra + if (ab, cd) == (pq, mn): + self.atable.add_const_angle(ab, cd, 90, dep) + else: + self.atable.add_eqangle(ab, cd, mn, pq, dep) + + if name == 'eqratio': + ab, cd, mn, pq = dep.algebra + if (ab, cd) == (pq, mn): + self.rtable.add_eq(ab, cd, dep) + else: + self.rtable.add_eqratio(ab, cd, mn, pq, dep) + + if name == 'aconst': + bx, ab, y = dep.algebra + self.atable.add_const_angle(bx, ab, y, dep) + + if name == 'rconst': + l1, l2, m, n = dep.algebra + self.rtable.add_const_ratio(l1, l2, m, n, dep) + + if name == 'cong': + a, b, c, d = args + ab, _ = self.get_line_thru_pair_why(a, b) + cd, _ = self.get_line_thru_pair_why(c, d) + self.dtable.add_cong(ab, cd, a, b, c, d, dep) + + ab, cd = dep.algebra + self.rtable.add_eq(ab, cd, dep) + + def add_eqrat_const( + self, args: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add new algebraic predicates of type eqratio-constant.""" + a, b, c, d, num, den = args + nd, dn = self.get_or_create_const_rat(num, den) + + if num == den: + return self.add_cong([a, b, c, d], deps) + + ab = self._get_or_create_segment(a, b, deps=None) + cd = self._get_or_create_segment(c, d, deps=None) + + self.connect_val(ab, deps=None) + self.connect_val(cd, deps=None) + + if ab.val == cd.val: + raise ValueError(f'{ab.name} and {cd.name} cannot be equal') + + args = [a, b, c, d, nd] + i = 0 + for x, y, xy in [(a, b, ab), (c, d, cd)]: + i += 1 + x_, y_ = list(xy._val._obj.points) + if {x, y} == {x_, y_}: + continue + if deps: + deps = deps.extend(self, 'rconst', list(args), 'cong', [x, y, x_, y_]) + args[2 * i - 2] = x_ + args[2 * i - 1] = y_ + + ab_cd, cd_ab, why = self._get_or_create_ratio(ab, cd, deps=None) + if why: + dep0 = deps.populate('rconst', [a, b, c, d, nd]) + 
deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + lab, lcd = ab_cd._l + a, b = list(lab._obj.points) + c, d = list(lcd._obj.points) + + add = [] + if not self.is_equal(ab_cd, nd): + args = [a, b, c, d, nd] + dep1 = deps.populate('rconst', args) + dep1.algebra = ab._val, cd._val, num, den + self.make_equal(nd, ab_cd, deps=dep1) + self.cache_dep('rconst', [a, b, c, d, nd], dep1) + add += [dep1] + + if not self.is_equal(cd_ab, dn): + args = [c, d, a, b, dn] + dep2 = deps.populate('rconst', args) + dep2.algebra = cd._val, ab._val, num, den + self.make_equal(dn, cd_ab, deps=dep2) + self.cache_dep('rconst', [c, d, a, b, dn], dep2) + add += [dep2] + + return add + + def do_algebra(self, name: str, args: list[Point]) -> list[Dependency]: + """Derive (but not add) new algebraic predicates.""" + if name == 'para': + a, b, dep = args + if gm.is_equiv(a, b): + return [] + (x, y), (m, n) = a._obj.points, b._obj.points + return self.add_para([x, y, m, n], dep) + + if name == 'aconst': + a, b, n, d, dep = args + ab, ba, why = self.get_or_create_angle_d(a, b, deps=None) + nd, dn = self.get_or_create_const_ang(n, d) + + (x, y), (m, n) = a._obj.points, b._obj.points + + if why: + dep0 = dep.populate('aconst', [x, y, m, n, nd]) + dep = EmptyDependency(level=dep.level, rule_name=None) + dep.why = [dep0] + why + + a, b = ab._d + (x, y), (m, n) = a._obj.points, b._obj.points + + added = [] + if not self.is_equal(ab, nd): + if nd == self.halfpi: + added += self.add_perp([x, y, m, n], dep) + # else: + name = 'aconst' + args = [x, y, m, n, nd] + dep1 = dep.populate(name, args) + self.cache_dep(name, args, dep1) + self.make_equal(nd, ab, deps=dep1) + added += [dep1] + + if not self.is_equal(ba, dn): + if dn == self.halfpi: + added += self.add_perp([m, n, x, y], dep) + name = 'aconst' + args = [m, n, x, y, dn] + dep2 = dep.populate(name, args) + self.cache_dep(name, args, dep2) + self.make_equal(dn, ba, deps=dep2) + added += [dep2] + return added + + if 
name == 'rconst':
+      a, b, c, d, num, den, dep = args
+      return self.add_eqrat_const([a, b, c, d, num, den], dep)
+
+    if name == 'eqangle':
+      d1, d2, d3, d4, dep = args
+      a, b = d1._obj.points
+      c, d = d2._obj.points
+      e, f = d3._obj.points
+      g, h = d4._obj.points
+
+      return self.add_eqangle([a, b, c, d, e, f, g, h], dep)
+
+    if name == 'eqratio':
+      d1, d2, d3, d4, dep = args
+      a, b = d1._obj.points
+      c, d = d2._obj.points
+      e, f = d3._obj.points
+      g, h = d4._obj.points
+
+      return self.add_eqratio([a, b, c, d, e, f, g, h], dep)
+
+    if name in ['cong', 'cong2']:
+      a, b, c, d, dep = args
+      if not (a != b and c != d and (a != c or b != d)):
+        return []
+      return self.add_cong([a, b, c, d], dep)
+
+    return []
+
+  def derive_algebra(
+      self, level: int, verbose: bool = False
+  ) -> tuple[
+      dict[str, list[tuple[Point, ...]]], dict[str, list[tuple[Point, ...]]]
+  ]:
+    """Derive new algebraic predicates."""
+    derives = {}
+    ang_derives = self.derive_angle_algebra(level, verbose=verbose)
+    dist_derives = self.derive_distance_algebra(level, verbose=verbose)
+    rat_derives = self.derive_ratio_algebra(level, verbose=verbose)
+
+    derives.update(ang_derives)
+    derives.update(dist_derives)
+    derives.update(rat_derives)
+
+    # Separate eqangle and eqratio derivations, as they are too numerous
+    # and slow down DD+AR; reserve them only as a last resort.
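+    # Shape sketch (illustrative assumption): 'derives' maps predicate names
+    # such as 'cong', 'para', 'aconst', 'rconst', 'inci' to lists of argument
+    # tuples, each ending with its EmptyDependency, e.g.
+    #   derives['cong'] == [(a, b, c, d, dep), ...]
+    # while the numerous 'eqangle'/'eqratio' tuples are split out separately.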
+ eqs = {'eqangle': derives.pop('eqangle'), 'eqratio': derives.pop('eqratio')} + return derives, eqs + + def derive_ratio_algebra( + self, level: int, verbose: bool = False + ) -> dict[str, list[tuple[Point, ...]]]: + """Derive new eqratio predicates.""" + added = {'cong2': [], 'eqratio': []} + + for x in self.rtable.get_all_eqs_and_why(): + x, why = x[:-1], x[-1] + dep = EmptyDependency(level=level, rule_name='a01') + dep.why = why + + if len(x) == 2: + a, b = x + if gm.is_equiv(a, b): + continue + + (m, n), (p, q) = a._obj.points, b._obj.points + added['cong2'].append((m, n, p, q, dep)) + + if len(x) == 4: + a, b, c, d = x + added['eqratio'].append((a, b, c, d, dep)) + + return added + + def derive_angle_algebra( + self, level: int, verbose: bool = False + ) -> dict[str, list[tuple[Point, ...]]]: + """Derive new eqangles predicates.""" + added = {'eqangle': [], 'aconst': [], 'para': []} + + for x in self.atable.get_all_eqs_and_why(): + x, why = x[:-1], x[-1] + dep = EmptyDependency(level=level, rule_name='a02') + dep.why = why + + if len(x) == 2: + a, b = x + if gm.is_equiv(a, b): + continue + + (e, f), (p, q) = a._obj.points, b._obj.points + if not nm.check('para', [e, f, p, q]): + continue + + added['para'].append((a, b, dep)) + + if len(x) == 3: + a, b, (n, d) = x + + (e, f), (p, q) = a._obj.points, b._obj.points + if not nm.check('aconst', [e, f, p, q, n, d]): + continue + + added['aconst'].append((a, b, n, d, dep)) + + if len(x) == 4: + a, b, c, d = x + added['eqangle'].append((a, b, c, d, dep)) + + return added + + def derive_distance_algebra( + self, level: int, verbose: bool = False + ) -> dict[str, list[tuple[Point, ...]]]: + """Derive new cong predicates.""" + added = {'inci': [], 'cong': [], 'rconst': []} + for x in self.dtable.get_all_eqs_and_why(): + x, why = x[:-1], x[-1] + dep = EmptyDependency(level=level, rule_name='a00') + dep.why = why + + if len(x) == 2: + a, b = x + if a == b: + continue + + dep.name = f'inci {a.name} {b.name}' + 
added['inci'].append((x, dep)) + + if len(x) == 4: + a, b, c, d = x + if not (a != b and c != d and (a != c or b != d)): + continue + added['cong'].append((a, b, c, d, dep)) + + if len(x) == 6: + a, b, c, d, num, den = x + if not (a != b and c != d and (a != c or b != d)): + continue + added['rconst'].append((a, b, c, d, num, den, dep)) + + return added + + @classmethod + def build_problem( + cls, + pr: problem.Problem, + definitions: dict[str, problem.Definition], + verbose: bool = True, + init_copy: bool = True, + ) -> tuple[Graph, list[Dependency]]: + """Build a problem into a gr.Graph object.""" + check = False + g = None + added = None + if verbose: + logging.info(pr.url) + logging.info(pr.txt()) + while not check: + try: + g = Graph() + added = [] + plevel = 0 + for clause in pr.clauses: + adds, plevel = g.add_clause( + clause, plevel, definitions, verbose=verbose + ) + added += adds + g.plevel = plevel + + except (nm.InvalidLineIntersectError, nm.InvalidQuadSolveError): + continue + except DepCheckFailError: + continue + except (PointTooCloseError, PointTooFarError): + continue + + if not pr.goal: + break + + args = list(map(lambda x: g.get(x, lambda: int(x)), pr.goal.args)) + check = nm.check(pr.goal.name, args) + + g.url = pr.url + g.build_def = (pr, definitions) + for add in added: + g.add_algebra(add, level=0) + + return g, added + + def all_points(self) -> list[Point]: + """Return all nodes of type Point.""" + return list(self.type2nodes[Point]) + + def all_nodes(self) -> list[gm.Node]: + """Return all nodes.""" + return list(self._name2node.values()) + + def add_points(self, pnames: list[str]) -> list[Point]: + """Add new points with given names in list pnames.""" + result = [self.new_node(Point, name) for name in pnames] + self._name2point.update(zip(pnames, result)) + return result + + def names2nodes(self, pnames: list[str]) -> list[gm.Node]: + return [self._name2node[name] for name in pnames] + + def names2points( + self, pnames: list[str], 
create_new_point: bool = False
+  ) -> list[Point]:
+    """Return Point objects given names."""
+    result = []
+    for name in pnames:
+      if name not in self._name2node and not create_new_point:
+        raise ValueError(f'Cannot find point {name} in graph')
+      elif name in self._name2node:
+        obj = self._name2node[name]
+      else:
+        obj = self.new_node(Point, name)
+      result.append(obj)
+
+    return result
+
+  def names2points_or_int(self, pnames: list[str]) -> list[Point]:
+    """Return Point objects, or constant angles/ratios, given names."""
+    result = []
+    for name in pnames:
+      if name.isdigit():
+        result += [int(name)]
+      elif 'pi/' in name:
+        n, d = name.split('pi/')
+        ang, _ = self.get_or_create_const_ang(int(n), int(d))
+        result += [ang]
+      elif '/' in name:
+        n, d = name.split('/')
+        rat, _ = self.get_or_create_const_rat(int(n), int(d))
+        result += [rat]
+      else:
+        result += [self._name2point[name]]
+
+    return result
+
+  def get(self, pointname: str, default_fn: Callable[[], Point]) -> Point:
+    if pointname in self._name2point:
+      return self._name2point[pointname]
+    if pointname in self._name2node:
+      return self._name2node[pointname]
+    return default_fn()
+
+  def new_node(self, oftype: Type[gm.Node], name: str = '') -> gm.Node:
+    node = oftype(name, self)
+
+    self.type2nodes[oftype].append(node)
+    self._name2node[name] = node
+
+    if isinstance(node, Point):
+      self._name2point[name] = node
+
+    return node
+
+  def merge(self, nodes: list[gm.Node], deps: Dependency) -> gm.Node:
+    """Merge all nodes."""
+    if len(nodes) < 2:
+      return
+
+    node0, *nodes1 = nodes
+    all_nodes = self.type2nodes[type(node0)]
+
+    # find node0 that exists in all_nodes to be the rep
+    # and merge all other nodes into node0
+    for node in nodes:
+      if node in all_nodes:
+        node0 = node
+        nodes1 = [n for n in nodes if n != node0]
+        break
+    return self.merge_into(node0, nodes1, deps)
+
+  def merge_into(
+      self, node0: gm.Node, nodes1: list[gm.Node], deps: Dependency
+  ) -> gm.Node:
+    """Merge nodes1 into a single node0."""
+    node0.merge(nodes1, 
deps)
+    for n in nodes1:
+      if n.rep() != n:
+        self.remove([n])
+
+    nodes = [node0] + nodes1
+    if any([node._val for node in nodes]):
+      for node in nodes:
+        self.connect_val(node, deps=None)
+
+      vals1 = [n._val for n in nodes1]
+      node0._val.merge(vals1, deps)
+
+      for v in vals1:
+        if v.rep() != v:
+          self.remove([v])
+
+    return node0
+
+  def remove(self, nodes: list[gm.Node]) -> None:
+    """Remove nodes out of self because they are merged."""
+    if not nodes:
+      return
+
+    for node in nodes:
+      all_nodes = self.type2nodes[type(nodes[0])]
+
+      if node in all_nodes:
+        all_nodes.remove(node)
+
+      if self._name2node.get(node.name) is node:
+        self._name2node.pop(node.name)
+
+  def connect(self, a: gm.Node, b: gm.Node, deps: Dependency) -> None:
+    a.connect_to(b, deps)
+    b.connect_to(a, deps)
+
+  def connect_val(self, node: gm.Node, deps: Dependency) -> gm.Node:
+    """Connect a node into its value (equality) node."""
+    if node._val:
+      return node._val
+    name = None
+    if isinstance(node, Line):
+      name = 'd(' + node.name + ')'
+    if isinstance(node, Angle):
+      name = 'm(' + node.name + ')'
+    if isinstance(node, Segment):
+      name = 'l(' + node.name + ')'
+    if isinstance(node, Ratio):
+      name = 'r(' + node.name + ')'
+    v = self.new_node(gm.val_type(node), name)
+    self.connect(node, v, deps=deps)
+    return v
+
+  def is_equal(self, x: gm.Node, y: gm.Node, level: int = None) -> bool:
+    return gm.is_equal(x, y, level)
+
+  def add_piece(
+      self, name: str, args: list[Point], deps: EmptyDependency
+  ) -> list[Dependency]:
+    """Add a new predicate."""
+    if name in ['coll', 'collx']:
+      return self.add_coll(args, deps)
+    elif name == 'para':
+      return self.add_para(args, deps)
+    elif name == 'perp':
+      return self.add_perp(args, deps)
+    elif name == 'midp':
+      return self.add_midp(args, deps)
+    elif name == 'cong':
+      return self.add_cong(args, deps)
+    elif name == 'circle':
+      return self.add_circle(args, deps)
+    elif name == 'cyclic':
+      return self.add_cyclic(args, deps)
+    elif name in ['eqangle', 
'eqangle6']:
+      return self.add_eqangle(args, deps)
+    elif name in ['eqratio', 'eqratio6']:
+      return self.add_eqratio(args, deps)
+    # numerical!
+    elif name == 's_angle':
+      return self.add_s_angle(args, deps)
+    elif name == 'aconst':
+      a, b, c, d, ang = args
+
+      if isinstance(ang, str):
+        name = ang
+      else:
+        name = ang.name
+
+      num, den = name.split('pi/')
+      num, den = int(num), int(den)
+      return self.add_aconst([a, b, c, d, num, den], deps)
+    elif name == 'rconst':
+      a, b, c, d, rat = args
+
+      if isinstance(rat, str):
+        name = rat
+      else:
+        name = rat.name
+
+      num, den = name.split('/')
+      num, den = int(num), int(den)
+      return self.add_eqrat_const([a, b, c, d, num, den], deps)
+
+    # composite pieces:
+    elif name == 'cong2':
+      return self.add_cong2(args, deps)
+    elif name == 'eqratio3':
+      return self.add_eqratio3(args, deps)
+    elif name == 'eqratio4':
+      return self.add_eqratio4(args, deps)
+    elif name == 'simtri':
+      return self.add_simtri(args, deps)
+    elif name == 'contri':
+      return self.add_contri(args, deps)
+    elif name == 'simtri2':
+      return self.add_simtri2(args, deps)
+    elif name == 'contri2':
+      return self.add_contri2(args, deps)
+    elif name == 'simtri*':
+      return self.add_simtri_check(args, deps)
+    elif name == 'contri*':
+      return self.add_contri_check(args, deps)
+    elif name in ['acompute', 'rcompute']:
+      dep = deps.populate(name, args)
+      self.cache_dep(name, args, dep)
+      return [dep]
+    elif name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']:
+      dep = deps.populate(name, args)
+      self.cache_dep(name, args, dep)
+      return [dep]
+    elif name in ['ind']:
+      return []
+    raise ValueError(f'Unrecognized predicate {name}')
+
+  def check(self, name: str, args: list[Point]) -> bool:
+    """Symbolically check if a predicate is True."""
+    if name == 'ncoll':
+      return self.check_ncoll(args)
+    if name == 'npara':
+      return self.check_npara(args)
+    if name == 'nperp':
+      return self.check_nperp(args)
+    if name == 'midp':
+      return self.check_midp(args)
+    if name == 'cong':
+      return self.check_cong(args)
+    if name == 'perp':
+      return self.check_perp(args)
+    if name == 'para':
+      return self.check_para(args)
+    if name == 'coll':
+      return self.check_coll(args)
+    if name == 'cyclic':
+      return self.check_cyclic(args)
+    if name == 'circle':
+      return self.check_circle(args)
+    if name == 'aconst':
+      return self.check_aconst(args)
+    if name == 'rconst':
+      return self.check_rconst(args)
+    if name == 'acompute':
+      return self.check_acompute(args)
+    if name == 'rcompute':
+      return self.check_rcompute(args)
+    if name in ['eqangle', 'eqangle6']:
+      if len(args) == 5:
+        return self.check_aconst(args)
+      return self.check_eqangle(args)
+    if name in ['eqratio', 'eqratio6']:
+      if len(args) == 5:
+        return self.check_rconst(args)
+      return self.check_eqratio(args)
+    if name in ['simtri', 'simtri2', 'simtri*']:
+      return self.check_simtri(args)
+    if name in ['contri', 'contri2', 'contri*']:
+      return self.check_contri(args)
+    if name == 'sameside':
+      return self.check_sameside(args)
+    if name == 'diff':
+      a, b = args
+      return not a.num.close(b.num)
+    if name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']:
+      return self.in_cache(name, args)
+    if name in ['ind']:
+      return True
+    raise ValueError(f'Unrecognized predicate {name}')
+
+  def get_lines_thru_all(self, *points: list[gm.Point]) -> list[Line]:
+    line2count = defaultdict(lambda: 0)
+    points = set(points)
+    for p in points:
+      for l in p.neighbors(Line):
+        line2count[l] += 1
+    return [l for l, count in line2count.items() if count == len(points)]
+
+  def _get_line(self, a: Point, b: Point) -> Optional[Line]:
+    linesa = a.neighbors(Line)
+    for l in b.neighbors(Line):
+      if l in linesa:
+        return l
+    return None
+
+  def _get_line_all(self, a: 
Point, b: Point) -> Generator[Line, None, None]: + linesa = a.neighbors(Line, do_rep=False) + linesb = b.neighbors(Line, do_rep=False) + for l in linesb: + if l in linesa: + yield l + + def _get_lines(self, *points: list[Point]) -> list[Line]: + """Return all lines that connect to >= 2 points.""" + line2count = defaultdict(lambda: 0) + for p in points: + for l in p.neighbors(Line): + line2count[l] += 1 + return [l for l, count in line2count.items() if count >= 2] + + def get_circle_thru_triplet(self, p1: Point, p2: Point, p3: Point) -> Circle: + p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name) + if (p1, p2, p3) in self._triplet2circle: + return self._triplet2circle[(p1, p2, p3)] + return self.get_new_circle_thru_triplet(p1, p2, p3) + + def get_new_circle_thru_triplet( + self, p1: Point, p2: Point, p3: Point + ) -> Circle: + """Get a new Circle that goes thru three given Points.""" + p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name) + name = p1.name.lower() + p2.name.lower() + p3.name.lower() + circle = self.new_node(Circle, f'({name})') + circle.num = nm.Circle(p1=p1.num, p2=p2.num, p3=p3.num) + circle.points = p1, p2, p3 + + self.connect(p1, circle, deps=None) + self.connect(p2, circle, deps=None) + self.connect(p3, circle, deps=None) + self._triplet2circle[(p1, p2, p3)] = circle + return circle + + def get_line_thru_pair(self, p1: Point, p2: Point) -> Line: + if (p1, p2) in self._pair2line: + return self._pair2line[(p1, p2)] + if (p2, p1) in self._pair2line: + return self._pair2line[(p2, p1)] + return self.get_new_line_thru_pair(p1, p2) + + def get_new_line_thru_pair(self, p1: Point, p2: Point) -> Line: + if p1.name.lower() > p2.name.lower(): + p1, p2 = p2, p1 + name = p1.name.lower() + p2.name.lower() + line = self.new_node(Line, name) + line.num = nm.Line(p1.num, p2.num) + line.points = p1, p2 + + self.connect(p1, line, deps=None) + self.connect(p2, line, deps=None) + self._pair2line[(p1, p2)] = line + return line + + def get_line_thru_pair_why( + 
self, p1: Point, p2: Point + ) -> tuple[Line, list[Dependency]]: + """Get one line thru two given points and the corresponding dependency list.""" + if p1.name.lower() > p2.name.lower(): + p1, p2 = p2, p1 + if (p1, p2) in self._pair2line: + return self._pair2line[(p1, p2)].rep_and_why() + + l, why = gm.line_of_and_why([p1, p2]) + if l is None: + l = self.get_new_line_thru_pair(p1, p2) + why = [] + return l, why + + def coll_dep(self, points: list[Point], p: Point) -> list[Dependency]: + """Return the dep(.why) explaining why p is coll with points.""" + for p1, p2 in utils.comb2(points): + if self.check_coll([p1, p2, p]): + dep = Dependency('coll', [p1, p2, p], None, None) + return dep.why_me_or_cache(self, None) + + def add_coll( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add a predicate that `points` are collinear.""" + points = list(set(points)) + og_points = list(points) + + all_lines = [] + for p1, p2 in utils.comb2(points): + all_lines.append(self.get_line_thru_pair(p1, p2)) + points = sum([l.neighbors(Point) for l in all_lines], []) + points = list(set(points)) + + existed = set() + new = set() + for p1, p2 in utils.comb2(points): + if p1.name > p2.name: + p1, p2 = p2, p1 + if (p1, p2) in self._pair2line: + line = self._pair2line[(p1, p2)] + existed.add(line) + else: + line = self.get_new_line_thru_pair(p1, p2) + new.add(line) + + existed = sorted(existed, key=lambda l: l.name) + new = sorted(new, key=lambda l: l.name) + + existed, new = list(existed), list(new) + if not existed: + line0, *lines = new + else: + line0, lines = existed[0], existed[1:] + new + + add = [] + line0, why0 = line0.rep_and_why() + a, b = line0.points + for line in lines: + c, d = line.points + args = list({a, b, c, d}) + if len(args) < 3: + continue + + whys = [] + for x in args: + if x not in og_points: + whys.append(self.coll_dep(og_points, x)) + + abcd_deps = deps + if whys + why0: + dep0 = deps.populate('coll', og_points) + abcd_deps = 
EmptyDependency(level=deps.level, rule_name=None) + abcd_deps.why = [dep0] + whys + + is_coll = self.check_coll(args) + dep = abcd_deps.populate('coll', args) + self.cache_dep('coll', args, dep) + self.merge_into(line0, [line], dep) + + if not is_coll: + add += [dep] + + return add + + def check_coll(self, points: list[Point]) -> bool: + points = list(set(points)) + if len(points) < 3: + return True + line2count = defaultdict(lambda: 0) + for p in points: + for l in p.neighbors(Line): + line2count[l] += 1 + return any([count == len(points) for _, count in line2count.items()]) + + def why_coll(self, args: tuple[Line, list[Point]]) -> list[Dependency]: + line, points = args + return line.why_coll(points) + + def check_ncoll(self, points: list[Point]) -> bool: + if self.check_coll(points): + return False + return not nm.check_coll([p.num for p in points]) + + def check_sameside(self, points: list[Point]) -> bool: + return nm.check_sameside([p.num for p in points]) + + def make_equal(self, x: gm.Node, y: gm.Node, deps: Dependency) -> None: + """Make that two nodes x and y are equal, i.e. 
merge their value node.""" + if x.val is None: + x, y = y, x + + self.connect_val(x, deps=None) + self.connect_val(y, deps=None) + vx = x._val + vy = y._val + + if vx == vy: + return + + merges = [vx, vy] + + if ( + isinstance(x, Angle) + and x not in self.aconst.values() + and y not in self.aconst.values() + and x.directions == y.directions[::-1] + and x.directions[0] != x.directions[1] + ): + merges = [self.vhalfpi, vx, vy] + + self.merge(merges, deps) + + def merge_vals(self, vx: gm.Node, vy: gm.Node, deps: Dependency) -> None: + if vx == vy: + return + merges = [vx, vy] + self.merge(merges, deps) + + def why_equal(self, x: gm.Node, y: gm.Node, level: int) -> list[Dependency]: + return gm.why_equal(x, y, level) + + def _why_coll4( + self, + a: Point, + b: Point, + ab: Line, + c: Point, + d: Point, + cd: Line, + level: int, + ) -> list[Dependency]: + return self._why_coll2(a, b, ab, level) + self._why_coll2(c, d, cd, level) + + def _why_coll8( + self, + a: Point, + b: Point, + ab: Line, + c: Point, + d: Point, + cd: Line, + m: Point, + n: Point, + mn: Line, + p: Point, + q: Point, + pq: Line, + level: int, + ) -> list[Dependency]: + """Dependency list of why 8 points are collinear.""" + why8 = self._why_coll4(a, b, ab, c, d, cd, level) + why8 += self._why_coll4(m, n, mn, p, q, pq, level) + return why8 + + def add_para( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add a new predicate that 4 points (2 lines) are parallel.""" + a, b, c, d = points + ab, why1 = self.get_line_thru_pair_why(a, b) + cd, why2 = self.get_line_thru_pair_why(c, d) + + is_equal = self.is_equal(ab, cd) + + (a, b), (c, d) = ab.points, cd.points + + dep0 = deps.populate('para', points) + deps = EmptyDependency(level=deps.level, rule_name=None) + + deps = deps.populate('para', [a, b, c, d]) + deps.why = [dep0] + why1 + why2 + + self.make_equal(ab, cd, deps) + deps.algebra = ab._val, cd._val + + self.cache_dep('para', [a, b, c, d], deps) + if not is_equal: + 
return [deps] + return [] + + def why_para(self, args: list[Point]) -> list[Dependency]: + ab, cd, lvl = args + return self.why_equal(ab, cd, lvl) + + def check_para_or_coll(self, points: list[Point]) -> bool: + return self.check_para(points) or self.check_coll(points) + + def check_para(self, points: list[Point]) -> bool: + a, b, c, d = points + if (a == b) or (c == d): + return False + ab = self._get_line(a, b) + cd = self._get_line(c, d) + if not ab or not cd: + return False + + return self.is_equal(ab, cd) + + def check_npara(self, points: list[Point]) -> bool: + if self.check_para(points): + return False + return not nm.check_para([p.num for p in points]) + + def _get_angle( + self, d1: Direction, d2: Direction + ) -> tuple[Angle, Optional[Angle]]: + for a in self.type2nodes[Angle]: + if a.directions == (d1, d2): + return a, a.opposite + return None, None + + def get_first_angle( + self, l1: Line, l2: Line + ) -> tuple[Angle, list[Dependency]]: + """Get a first angle between line l1 and line l2.""" + d1, d2 = l1._val, l2._val + + d1s = d1.all_reps() + d2s = d2.all_reps() + + found = d1.first_angle(d2s) + if found is None: + found = d2.first_angle(d1s) + if found is None: + return None, [] + ang, x2, x1 = found + found = ang.opposite, x1, x2 + + ang, x1, x2 = found + return ang, d1.deps_upto(x1) + d2.deps_upto(x2) + + def _get_or_create_angle( + self, l1: Line, l2: Line, deps: Dependency + ) -> tuple[Angle, Angle, list[Dependency]]: + return self.get_or_create_angle_d(l1._val, l2._val, deps) + + def get_or_create_angle_d( + self, d1: Direction, d2: Direction, deps: Dependency + ) -> tuple[Angle, Angle, list[Dependency]]: + """Get or create an angle between two Direction d1 and d2.""" + for a in self.type2nodes[Angle]: + if a.directions == (d1.rep(), d2.rep()): # directions = _d.rep() + d1_, d2_ = a._d + why1 = d1.why_equal([d1_], None) + d1_.why_rep() + why2 = d2.why_equal([d2_], None) + d2_.why_rep() + return a, a.opposite, why1 + why2 + + d1, why1 = 
d1.rep_and_why() + d2, why2 = d2.rep_and_why() + a12 = self.new_node(Angle, f'{d1.name}-{d2.name}') + a21 = self.new_node(Angle, f'{d2.name}-{d1.name}') + self.connect(d1, a12, deps) + self.connect(d2, a21, deps) + self.connect(a12, a21, deps) + a12.set_directions(d1, d2) + a21.set_directions(d2, d1) + a12.opposite = a21 + a21.opposite = a12 + return a12, a21, why1 + why2 + + def _add_para_or_coll( + self, + a: Point, + b: Point, + c: Point, + d: Point, + x: Point, + y: Point, + m: Point, + n: Point, + deps: EmptyDependency, + ) -> list[Dependency]: + """Add a new parallel or collinear predicate.""" + extends = [('perp', [x, y, m, n])] + if {a, b} == {x, y}: + pass + elif self.check_para([a, b, x, y]): + extends.append(('para', [a, b, x, y])) + elif self.check_coll([a, b, x, y]): + extends.append(('coll', set(list([a, b, x, y])))) + else: + return None + + if m in [c, d] or n in [c, d] or c in [m, n] or d in [m, n]: + pass + elif self.check_coll([c, d, m]): + extends.append(('coll', [c, d, m])) + elif self.check_coll([c, d, n]): + extends.append(('coll', [c, d, n])) + elif self.check_coll([c, m, n]): + extends.append(('coll', [c, m, n])) + elif self.check_coll([d, m, n]): + extends.append(('coll', [d, m, n])) + else: + deps = deps.extend_many(self, 'perp', [a, b, c, d], extends) + return self.add_para([c, d, m, n], deps) + + deps = deps.extend_many(self, 'perp', [a, b, c, d], extends) + return self.add_coll(list(set([c, d, m, n])), deps) + + def maybe_make_para_from_perp( + self, points: list[Point], deps: EmptyDependency + ) -> Optional[list[Dependency]]: + """Maybe add a new parallel predicate from perp predicate.""" + a, b, c, d = points + halfpi = self.aconst[(1, 2)] + for ang in halfpi.val.neighbors(Angle): + if ang == halfpi: + continue + d1, d2 = ang.directions + x, y = d1._obj.points + m, n = d2._obj.points + + for args in [ + (a, b, c, d, x, y, m, n), + (a, b, c, d, m, n, x, y), + (c, d, a, b, x, y, m, n), + (c, d, a, b, m, n, x, y), + ]: + args = args + 
(deps,) + add = self._add_para_or_coll(*args) + if add: + return add + + return None + + def add_perp( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add a new perpendicular predicate from 4 points (2 lines).""" + add = self.maybe_make_para_from_perp(points, deps) + if add is not None: + return add + + a, b, c, d = points + ab, why1 = self.get_line_thru_pair_why(a, b) + cd, why2 = self.get_line_thru_pair_why(c, d) + + (a, b), (c, d) = ab.points, cd.points + + if why1 + why2: + dep0 = deps.populate('perp', points) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why1 + why2 + + self.connect_val(ab, deps=None) + self.connect_val(cd, deps=None) + + if ab.val == cd.val: + raise ValueError(f'{ab.name} and {cd.name} Cannot be perp.') + + args = [a, b, c, d] + i = 0 + for x, y, xy in [(a, b, ab), (c, d, cd)]: + i += 1 + x_, y_ = xy._val._obj.points + if {x, y} == {x_, y_}: + continue + if deps: + deps = deps.extend(self, 'perp', list(args), 'para', [x, y, x_, y_]) + args[2 * i - 2] = x_ + args[2 * i - 1] = y_ + + a12, a21, why = self._get_or_create_angle(ab, cd, deps=None) + + if why: + dep0 = deps.populate('perp', [a, b, c, d]) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + dab, dcd = a12._d + a, b = dab._obj.points + c, d = dcd._obj.points + + is_equal = self.is_equal(a12, a21) + deps = deps.populate('perp', [a, b, c, d]) + deps.algebra = [dab, dcd] + self.make_equal(a12, a21, deps=deps) + + self.cache_dep('perp', [a, b, c, d], deps) + self.cache_dep('eqangle', [a, b, c, d, c, d, a, b], deps) + + if not is_equal: + return [deps] + return [] + + def why_perp( + self, args: list[Union[Point, list[Dependency]]] + ) -> list[Dependency]: + a, b, deps = args + return deps + self.why_equal(a, b, None) + + def check_perpl(self, ab: Line, cd: Line) -> bool: + if ab.val is None or cd.val is None: + return False + if ab.val == cd.val: + return False + a12, a21 = 
self._get_angle(ab.val, cd.val) + if a12 is None or a21 is None: + return False + return self.is_equal(a12, a21) + + def check_perp(self, points: list[Point]) -> bool: + a, b, c, d = points + ab = self._get_line(a, b) + cd = self._get_line(c, d) + if not ab or not cd: + return False + return self.check_perpl(ab, cd) + + def check_nperp(self, points: list[Point]) -> bool: + if self.check_perp(points): + return False + return not nm.check_perp([p.num for p in points]) + + def _get_segment(self, p1: Point, p2: Point) -> Optional[Segment]: + for s in self.type2nodes[Segment]: + if s.points == {p1, p2}: + return s + return None + + def _get_or_create_segment( + self, p1: Point, p2: Point, deps: Dependency + ) -> Segment: + """Get or create a Segment object between two Points p1 and p2.""" + if p1 == p2: + raise ValueError(f'Creating same 0-length segment {p1.name}') + + for s in self.type2nodes[Segment]: + if s.points == {p1, p2}: + return s + + if p1.name > p2.name: + p1, p2 = p2, p1 + s = self.new_node(Segment, name=f'{p1.name.upper()}{p2.name.upper()}') + self.connect(p1, s, deps=deps) + self.connect(p2, s, deps=deps) + s.points = {p1, p2} + return s + + def add_cong( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add that two segments (4 points) are congruent.""" + a, b, c, d = points + ab = self._get_or_create_segment(a, b, deps=None) + cd = self._get_or_create_segment(c, d, deps=None) + + is_equal = self.is_equal(ab, cd) + + dep = deps.populate('cong', [a, b, c, d]) + self.make_equal(ab, cd, deps=dep) + dep.algebra = ab._val, cd._val + + self.cache_dep('cong', [a, b, c, d], dep) + + result = [] + + if not is_equal: + result += [dep] + + if a not in [c, d] and b not in [c, d]: + return result + + if b in [c, d]: + a, b = b, a + if a == d: + c, d = d, c # pylint: disable=unused-variable + + result += self._maybe_add_cyclic_from_cong(a, b, d, dep) + return result + + def _maybe_add_cyclic_from_cong( + self, a: Point, b: Point, c: 
Point, cong_ab_ac: Dependency + ) -> list[Dependency]: + """Maybe add a new cyclic predicate from given congruent segments.""" + ab = self._get_or_create_segment(a, b, deps=None) + + # all eq segs with one end being a. + segs = [s for s in ab.val.neighbors(Segment) if a in s.points] + + # all points on circle (a, b) + points = [] + for s in segs: + x, y = list(s.points) + points.append(x if y == a else y) + + # for sure both b and c are in points + points = [p for p in points if p not in [b, c]] + + if len(points) < 2: + return [] + + x, y = points[:2] + + if self.check_cyclic([b, c, x, y]): + return [] + + ax = self._get_or_create_segment(a, x, deps=None) + ay = self._get_or_create_segment(a, y, deps=None) + why = ab._val.why_equal([ax._val, ay._val], level=None) + why += [cong_ab_ac] + + deps = EmptyDependency(cong_ab_ac.level, '') + deps.why = why + + return self.add_cyclic([b, c, x, y], deps) + + def check_cong(self, points: list[Point]) -> bool: + a, b, c, d = points + if {a, b} == {c, d}: + return True + + ab = self._get_segment(a, b) + cd = self._get_segment(c, d) + if ab is None or cd is None: + return False + return self.is_equal(ab, cd) + + def why_cong(self, args: tuple[Segment, Segment]) -> list[Dependency]: + ab, cd = args + return self.why_equal(ab, cd, None) + + def add_midp( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + m, a, b = points + add = self.add_coll(points, deps=deps) + add += self.add_cong([m, a, m, b], deps) + return add + + def why_midp( + self, args: tuple[Line, list[Point], Segment, Segment] + ) -> list[Dependency]: + line, points, ma, mb = args + return self.why_coll([line, points]) + self.why_cong([ma, mb]) + + def check_midp(self, points: list[Point]) -> bool: + if not self.check_coll(points): + return False + m, a, b = points + return self.check_cong([m, a, m, b]) + + def add_circle( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + o, a, b, c = points + add = 
self.add_cong([o, a, o, b], deps=deps)
+    add += self.add_cong([o, a, o, c], deps=deps)
+    return add
+
+  def why_circle(
+      self, args: tuple[Segment, Segment, Segment]
+  ) -> list[Dependency]:
+    oa, ob, oc = args
+    return self.why_equal(oa, ob, None) + self.why_equal(oa, oc, None)
+
+  def check_circle(self, points: list[Point]) -> bool:
+    o, a, b, c = points
+    return self.check_cong([o, a, o, b]) and self.check_cong([o, a, o, c])
+
+  def get_circles_thru_all(self, *points: list[Point]) -> list[Circle]:
+    circle2count = defaultdict(lambda: 0)
+    points = set(points)
+    for p in points:
+      for c in p.neighbors(Circle):
+        circle2count[c] += 1
+    return [c for c, count in circle2count.items() if count == len(points)]
+
+  def _get_circles(self, *points: list[Point]) -> list[Circle]:
+    circle2count = defaultdict(lambda: 0)
+    for p in points:
+      for c in p.neighbors(Circle):
+        circle2count[c] += 1
+    return [c for c, count in circle2count.items() if count >= 3]
+
+  def cyclic_dep(self, points: list[Point], p: Point) -> list[Dependency]:
+    for p1, p2, p3 in utils.comb3(points):
+      if self.check_cyclic([p1, p2, p3, p]):
+        dep = Dependency('cyclic', [p1, p2, p3, p], None, None)
+        return dep.why_me_or_cache(self, None)
+
+  def add_cyclic(
+      self, points: list[Point], deps: EmptyDependency
+  ) -> list[Dependency]:
+    """Add a new cyclic predicate that 4 points are concyclic."""
+    points = list(set(points))
+    og_points = list(points)
+
+    all_circles = []
+    for p1, p2, p3 in utils.comb3(points):
+      all_circles.append(self.get_circle_thru_triplet(p1, p2, p3))
+    points = sum([c.neighbors(Point) for c in all_circles], [])
+    points = list(set(points))
+
+    existed = set()
+    new = set()
+    for p1, p2, p3 in utils.comb3(points):
+      p1, p2, p3 = sorted([p1, p2, p3], key=lambda x: x.name)
+
+      if (p1, p2, p3) in self._triplet2circle:
+        circle = self._triplet2circle[(p1, p2, p3)]
+        existed.add(circle)
+      else:
+        circle = self.get_new_circle_thru_triplet(p1, p2, p3)
+        new.add(circle)
+
+    existed = 
sorted(existed, key=lambda l: l.name) + new = sorted(new, key=lambda l: l.name) + + existed, new = list(existed), list(new) + if not existed: + circle0, *circles = new + else: + circle0, circles = existed[0], existed[1:] + new + + add = [] + circle0, why0 = circle0.rep_and_why() + a, b, c = circle0.points + for circle in circles: + d, e, f = circle.points + args = list({a, b, c, d, e, f}) + if len(args) < 4: + continue + whys = [] + for x in [a, b, c, d, e, f]: + if x not in og_points: + whys.append(self.cyclic_dep(og_points, x)) + abcdef_deps = deps + if whys + why0: + dep0 = deps.populate('cyclic', og_points) + abcdef_deps = EmptyDependency(level=deps.level, rule_name=None) + abcdef_deps.why = [dep0] + whys + + is_cyclic = self.check_cyclic(args) + + dep = abcdef_deps.populate('cyclic', args) + self.cache_dep('cyclic', args, dep) + self.merge_into(circle0, [circle], dep) + if not is_cyclic: + add += [dep] + + return add + + def check_cyclic(self, points: list[Point]) -> bool: + points = list(set(points)) + if len(points) < 4: + return True + circle2count = defaultdict(lambda: 0) + for p in points: + for c in p.neighbors(Circle): + circle2count[c] += 1 + return any([count == len(points) for _, count in circle2count.items()]) + + def make_equal_pairs( + self, + a: Point, + b: Point, + c: Point, + d: Point, + m: Point, + n: Point, + p: Point, + q: Point, + ab: Line, + cd: Line, + mn: Line, + pq: Line, + deps: EmptyDependency, + ) -> list[Dependency]: + """Add ab/cd = mn/pq in case either two of (ab,cd,mn,pq) are equal.""" + depname = 'eqratio' if isinstance(ab, Segment) else 'eqangle' + eqname = 'cong' if isinstance(ab, Segment) else 'para' + + is_equal = self.is_equal(mn, pq) + + if ab != cd: + dep0 = deps.populate(depname, [a, b, c, d, m, n, p, q]) + deps = EmptyDependency(level=deps.level, rule_name=None) + + dep = Dependency(eqname, [a, b, c, d], None, deps.level) + deps.why = [dep0, dep.why_me_or_cache(self, None)] + + elif eqname == 'para': # ab == cd. 
+ colls = [a, b, c, d] + if len(set(colls)) > 2: + dep0 = deps.populate(depname, [a, b, c, d, m, n, p, q]) + deps = EmptyDependency(level=deps.level, rule_name=None) + + dep = Dependency('collx', colls, None, deps.level) + deps.why = [dep0, dep.why_me_or_cache(self, None)] + + deps = deps.populate(eqname, [m, n, p, q]) + self.make_equal(mn, pq, deps=deps) + + deps.algebra = mn._val, pq._val + self.cache_dep(eqname, [m, n, p, q], deps) + + if is_equal: + return [] + return [deps] + + def maybe_make_equal_pairs( + self, + a: Point, + b: Point, + c: Point, + d: Point, + m: Point, + n: Point, + p: Point, + q: Point, + ab: Line, + cd: Line, + mn: Line, + pq: Line, + deps: EmptyDependency, + ) -> Optional[list[Dependency]]: + """Add ab/cd = mn/pq in case maybe either two of (ab,cd,mn,pq) are equal.""" + level = deps.level + if self.is_equal(ab, cd, level): + return self.make_equal_pairs(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps) + elif self.is_equal(mn, pq, level): + return self.make_equal_pairs( # pylint: disable=arguments-out-of-order + m, + n, + p, + q, + a, + b, + c, + d, + mn, + pq, + ab, + cd, + deps, + ) + elif self.is_equal(ab, mn, level): + return self.make_equal_pairs( # pylint: disable=arguments-out-of-order + a, + b, + m, + n, + c, + d, + p, + q, + ab, + mn, + cd, + pq, + deps, + ) + elif self.is_equal(cd, pq, level): + return self.make_equal_pairs( # pylint: disable=arguments-out-of-order + c, + d, + p, + q, + a, + b, + m, + n, + cd, + pq, + ab, + mn, + deps, + ) + else: + return None + + def _add_eqangle( + self, + a: Point, + b: Point, + c: Point, + d: Point, + m: Point, + n: Point, + p: Point, + q: Point, + ab: Line, + cd: Line, + mn: Line, + pq: Line, + deps: EmptyDependency, + ) -> list[Dependency]: + """Add eqangle core.""" + if deps: + deps = deps.copy() + + args = [a, b, c, d, m, n, p, q] + i = 0 + for x, y, xy in [(a, b, ab), (c, d, cd), (m, n, mn), (p, q, pq)]: + i += 1 + x_, y_ = xy._val._obj.points + if {x, y} == {x_, y_}: + continue + if 
deps: + deps = deps.extend(self, 'eqangle', list(args), 'para', [x, y, x_, y_]) + + args[2 * i - 2] = x_ + args[2 * i - 1] = y_ + + add = [] + ab_cd, cd_ab, why1 = self._get_or_create_angle(ab, cd, deps=None) + mn_pq, pq_mn, why2 = self._get_or_create_angle(mn, pq, deps=None) + + why = why1 + why2 + if why: + dep0 = deps.populate('eqangle', args) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + dab, dcd = ab_cd._d + dmn, dpq = mn_pq._d + + a, b = dab._obj.points + c, d = dcd._obj.points + m, n = dmn._obj.points + p, q = dpq._obj.points + + is_eq1 = self.is_equal(ab_cd, mn_pq) + deps1 = None + if deps: + deps1 = deps.populate('eqangle', [a, b, c, d, m, n, p, q]) + deps1.algebra = [dab, dcd, dmn, dpq] + if not is_eq1: + add += [deps1] + self.cache_dep('eqangle', [a, b, c, d, m, n, p, q], deps1) + self.make_equal(ab_cd, mn_pq, deps=deps1) + + is_eq2 = self.is_equal(cd_ab, pq_mn) + deps2 = None + if deps: + deps2 = deps.populate('eqangle', [c, d, a, b, p, q, m, n]) + deps2.algebra = [dcd, dab, dpq, dmn] + if not is_eq2: + add += [deps2] + self.cache_dep('eqangle', [c, d, a, b, p, q, m, n], deps2) + self.make_equal(cd_ab, pq_mn, deps=deps2) + + return add + + def add_eqangle( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add eqangle made by 8 points in `points`.""" + if deps: + deps = deps.copy() + a, b, c, d, m, n, p, q = points + ab, why1 = self.get_line_thru_pair_why(a, b) + cd, why2 = self.get_line_thru_pair_why(c, d) + mn, why3 = self.get_line_thru_pair_why(m, n) + pq, why4 = self.get_line_thru_pair_why(p, q) + + a, b = ab.points + c, d = cd.points + m, n = mn.points + p, q = pq.points + + if deps and why1 + why2 + why3 + why4: + dep0 = deps.populate('eqangle', points) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why1 + why2 + why3 + why4 + + add = self.maybe_make_equal_pairs( + a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps + ) + + if add is not None: + 
return add + + self.connect_val(ab, deps=None) + self.connect_val(cd, deps=None) + self.connect_val(mn, deps=None) + self.connect_val(pq, deps=None) + + add = [] + if ( + ab.val != cd.val + and mn.val != pq.val + and (ab.val != mn.val or cd.val != pq.val) + ): + add += self._add_eqangle(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps) + + if ( + ab.val != mn.val + and cd.val != pq.val + and (ab.val != cd.val or mn.val != pq.val) + ): + add += self._add_eqangle( # pylint: disable=arguments-out-of-order + a, + b, + m, + n, + c, + d, + p, + q, + ab, + mn, + cd, + pq, + deps, + ) + + return add + + def add_aconst( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add that an angle is equal to some constant.""" + a, b, c, d, num, den = points + nd, dn = self.get_or_create_const_ang(num, den) + + if nd == self.halfpi: + return self.add_perp([a, b, c, d], deps) + + ab, why1 = self.get_line_thru_pair_why(a, b) + cd, why2 = self.get_line_thru_pair_why(c, d) + + (a, b), (c, d) = ab.points, cd.points + if why1 + why2: + args = points[:-2] + [nd] + dep0 = deps.populate('aconst', args) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why1 + why2 + + self.connect_val(ab, deps=None) + self.connect_val(cd, deps=None) + + if ab.val == cd.val: + raise ValueError(f'{ab.name} - {cd.name} cannot be {nd.name}') + + args = [a, b, c, d, nd] + i = 0 + for x, y, xy in [(a, b, ab), (c, d, cd)]: + i += 1 + x_, y_ = xy._val._obj.points + if {x, y} == {x_, y_}: + continue + if deps: + deps = deps.extend(self, 'aconst', list(args), 'para', [x, y, x_, y_]) + args[2 * i - 2] = x_ + args[2 * i - 1] = y_ + + ab_cd, cd_ab, why = self._get_or_create_angle(ab, cd, deps=None) + if why: + dep0 = deps.populate('aconst', [a, b, c, d, nd]) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + dab, dcd = ab_cd._d + a, b = dab._obj.points + c, d = dcd._obj.points + + ang = int(num) * 180 / int(den) + add = [] + if 
not self.is_equal(ab_cd, nd): + deps1 = deps.populate('aconst', [a, b, c, d, nd]) + deps1.algebra = dab, dcd, ang % 180 + self.make_equal(ab_cd, nd, deps=deps1) + self.cache_dep('aconst', [a, b, c, d, nd], deps1) + add += [deps1] + + if not self.is_equal(cd_ab, dn): + deps2 = deps.populate('aconst', [c, d, a, b, dn]) + deps2.algebra = dcd, dab, 180 - ang % 180 + self.make_equal(cd_ab, dn, deps=deps2) + self.cache_dep('aconst', [c, d, a, b, dn], deps2) + add += [deps2] + return add + + def add_s_angle( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add that an angle abx is equal to constant y.""" + a, b, x, y = points + + n, d = ar.simplify(y % 180, 180) + nd, dn = self.get_or_create_const_ang(n, d) + + if nd == self.halfpi: + return self.add_perp([a, b, b, x], deps) + + ab, why1 = self.get_line_thru_pair_why(a, b) + bx, why2 = self.get_line_thru_pair_why(b, x) + + self.connect_val(ab, deps=None) + self.connect_val(bx, deps=None) + add = [] + + if ab.val == bx.val: + return add + + deps.why += why1 + why2 + + for p, q, pq in [(a, b, ab), (b, x, bx)]: + p_, q_ = pq.val._obj.points + if {p, q} == {p_, q_}: + continue + dep = Dependency('para', [p, q, p_, q_], None, deps.level) + deps.why += [dep.why_me_or_cache(self, None)] + + xba, abx, why = self._get_or_create_angle(bx, ab, deps=None) + if why: + dep0 = deps.populate('aconst', [b, x, a, b, nd]) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + dab, dbx = abx._d + a, b = dab._obj.points + c, x = dbx._obj.points + + if not self.is_equal(xba, nd): + deps1 = deps.populate('aconst', [c, x, a, b, nd]) + deps1.algebra = dbx, dab, y % 180 + + self.make_equal(xba, nd, deps=deps1) + self.cache_dep('aconst', [c, x, a, b, nd], deps1) + add += [deps1] + + if not self.is_equal(abx, dn): + deps2 = deps.populate('aconst', [a, b, c, x, dn]) + deps2.algebra = dab, dbx, 180 - (y % 180) + + self.make_equal(abx, dn, deps=deps2) + self.cache_dep('s_angle', [a, b, 
c, x, dn], deps2) + add += [deps2] + return add + + def check_aconst(self, points: list[Point], verbose: bool = False) -> bool: + """Check if the angle is equal to a certain constant.""" + a, b, c, d, nd = points + _ = verbose + if isinstance(nd, str): + name = nd + else: + name = nd.name + num, den = name.split('pi/') + ang, _ = self.get_or_create_const_ang(int(num), int(den)) + + ab = self._get_line(a, b) + cd = self._get_line(c, d) + if not ab or not cd: + return False + + if not (ab.val and cd.val): + return False + + for ang1, _, _ in gm.all_angles(ab._val, cd._val): + if self.is_equal(ang1, ang): + return True + return False + + def check_acompute(self, points: list[Point]) -> bool: + """Check if an angle has a constant value.""" + a, b, c, d = points + ab = self._get_line(a, b) + cd = self._get_line(c, d) + if not ab or not cd: + return False + + if not (ab.val and cd.val): + return False + + for ang0 in self.aconst.values(): + for ang in ang0.val.neighbors(Angle): + d1, d2 = ang.directions + if ab.val == d1 and cd.val == d2: + return True + return False + + def check_eqangle(self, points: list[Point]) -> bool: + """Check if two angles are equal.""" + a, b, c, d, m, n, p, q = points + + if {a, b} == {c, d} and {m, n} == {p, q}: + return True + if {a, b} == {m, n} and {c, d} == {p, q}: + return True + + if (a == b) or (c == d) or (m == n) or (p == q): + return False + ab = self._get_line(a, b) + cd = self._get_line(c, d) + mn = self._get_line(m, n) + pq = self._get_line(p, q) + + if {a, b} == {c, d} and mn and pq and self.is_equal(mn, pq): + return True + if {a, b} == {m, n} and cd and pq and self.is_equal(cd, pq): + return True + if {p, q} == {m, n} and ab and cd and self.is_equal(ab, cd): + return True + if {p, q} == {c, d} and ab and mn and self.is_equal(ab, mn): + return True + + if not ab or not cd or not mn or not pq: + return False + + if self.is_equal(ab, cd) and self.is_equal(mn, pq): + return True + if self.is_equal(ab, mn) and self.is_equal(cd, 
pq): + return True + + if not (ab.val and cd.val and mn.val and pq.val): + return False + + if (ab.val, cd.val) == (mn.val, pq.val) or (ab.val, mn.val) == ( + cd.val, + pq.val, + ): + return True + + for ang1, _, _ in gm.all_angles(ab._val, cd._val): + for ang2, _, _ in gm.all_angles(mn._val, pq._val): + if self.is_equal(ang1, ang2): + return True + + if self.check_perp([a, b, m, n]) and self.check_perp([c, d, p, q]): + return True + if self.check_perp([a, b, p, q]) and self.check_perp([c, d, m, n]): + return True + + return False + + def _get_ratio(self, l1: Length, l2: Length) -> tuple[Ratio, Ratio]: + for r in self.type2nodes[Ratio]: + if r.lengths == (l1, l2): + return r, r.opposite + return None, None + + def _get_or_create_ratio( + self, s1: Segment, s2: Segment, deps: Dependency + ) -> tuple[Ratio, Ratio, list[Dependency]]: + return self._get_or_create_ratio_l(s1._val, s2._val, deps) + + def _get_or_create_ratio_l( + self, l1: Length, l2: Length, deps: Dependency + ) -> tuple[Ratio, Ratio, list[Dependency]]: + """Get or create a new Ratio from two Lengths l1 and l2.""" + for r in self.type2nodes[Ratio]: + if r.lengths == (l1.rep(), l2.rep()): + l1_, l2_ = r._l + why1 = l1.why_equal([l1_], None) + l1_.why_rep() + why2 = l2.why_equal([l2_], None) + l2_.why_rep() + return r, r.opposite, why1 + why2 + + l1, why1 = l1.rep_and_why() + l2, why2 = l2.rep_and_why() + r12 = self.new_node(Ratio, f'{l1.name}/{l2.name}') + r21 = self.new_node(Ratio, f'{l2.name}/{l1.name}') + self.connect(l1, r12, deps) + self.connect(l2, r21, deps) + self.connect(r12, r21, deps) + r12.set_lengths(l1, l2) + r21.set_lengths(l2, l1) + r12.opposite = r21 + r21.opposite = r12 + return r12, r21, why1 + why2 + + def add_cong2( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + m, n, a, b = points + add = [] + add += self.add_cong([m, a, n, a], deps) + add += self.add_cong([m, b, n, b], deps) + return add + + def add_eqratio3( + self, points: list[Point], deps: 
EmptyDependency + ) -> list[Dependency]: + """Add three eqratios through a list of 6 points (due to parallel lines).""" + a, b, c, d, m, n = points + # a -- b + # m -- n + # c -- d + add = [] + add += self.add_eqratio([m, a, m, c, n, b, n, d], deps) + add += self.add_eqratio([a, m, a, c, b, n, b, d], deps) + add += self.add_eqratio([c, m, c, a, d, n, d, b], deps) + if m == n: + add += self.add_eqratio([m, a, m, c, a, b, c, d], deps) + return add + + def add_eqratio4( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + o, a, b, c, d = points + # o + # a b + # c d + add = self.add_eqratio3([a, b, c, d, o, o], deps) + add += self.add_eqratio([o, a, o, c, a, b, c, d], deps) + return add + + def _add_eqratio( + self, + a: Point, + b: Point, + c: Point, + d: Point, + m: Point, + n: Point, + p: Point, + q: Point, + ab: Segment, + cd: Segment, + mn: Segment, + pq: Segment, + deps: EmptyDependency, + ) -> list[Dependency]: + """Add a new eqratio from 8 points (core).""" + if deps: + deps = deps.copy() + + args = [a, b, c, d, m, n, p, q] + i = 0 + for x, y, xy in [(a, b, ab), (c, d, cd), (m, n, mn), (p, q, pq)]: + if {x, y} == set(xy.points): + continue + x_, y_ = list(xy.points) + if deps: + deps = deps.extend(self, 'eqratio', list(args), 'cong', [x, y, x_, y_]) + args[2 * i - 2] = x_ + args[2 * i - 1] = y_ + + add = [] + ab_cd, cd_ab, why1 = self._get_or_create_ratio(ab, cd, deps=None) + mn_pq, pq_mn, why2 = self._get_or_create_ratio(mn, pq, deps=None) + + why = why1 + why2 + if why: + dep0 = deps.populate('eqratio', args) + deps = EmptyDependency(level=deps.level, rule_name=None) + deps.why = [dep0] + why + + lab, lcd = ab_cd._l + lmn, lpq = mn_pq._l + + a, b = lab._obj.points + c, d = lcd._obj.points + m, n = lmn._obj.points + p, q = lpq._obj.points + + is_eq1 = self.is_equal(ab_cd, mn_pq) + deps1 = None + if deps: + deps1 = deps.populate('eqratio', [a, b, c, d, m, n, p, q]) + deps1.algebra = [ab._val, cd._val, mn._val, pq._val] + if not is_eq1: 
+ add += [deps1] + self.cache_dep('eqratio', [a, b, c, d, m, n, p, q], deps1) + self.make_equal(ab_cd, mn_pq, deps=deps1) + + is_eq2 = self.is_equal(cd_ab, pq_mn) + deps2 = None + if deps: + deps2 = deps.populate('eqratio', [c, d, a, b, p, q, m, n]) + deps2.algebra = [cd._val, ab._val, pq._val, mn._val] + if not is_eq2: + add += [deps2] + self.cache_dep('eqratio', [c, d, a, b, p, q, m, n], deps2) + self.make_equal(cd_ab, pq_mn, deps=deps2) + return add + + def add_eqratio( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add a new eqratio from 8 points.""" + if deps: + deps = deps.copy() + a, b, c, d, m, n, p, q = points + ab = self._get_or_create_segment(a, b, deps=None) + cd = self._get_or_create_segment(c, d, deps=None) + mn = self._get_or_create_segment(m, n, deps=None) + pq = self._get_or_create_segment(p, q, deps=None) + + add = self.maybe_make_equal_pairs( + a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps + ) + + if add is not None: + return add + + self.connect_val(ab, deps=None) + self.connect_val(cd, deps=None) + self.connect_val(mn, deps=None) + self.connect_val(pq, deps=None) + + add = [] + if ( + ab.val != cd.val + and mn.val != pq.val + and (ab.val != mn.val or cd.val != pq.val) + ): + add += self._add_eqratio(a, b, c, d, m, n, p, q, ab, cd, mn, pq, deps) + + if ( + ab.val != mn.val + and cd.val != pq.val + and (ab.val != cd.val or mn.val != pq.val) + ): + add += self._add_eqratio( # pylint: disable=arguments-out-of-order + a, + b, + m, + n, + c, + d, + p, + q, + ab, + mn, + cd, + pq, + deps, + ) + return add + + def check_rconst(self, points: list[Point], verbose: bool = False) -> bool: + """Check whether a ratio is equal to some given constant.""" + _ = verbose + a, b, c, d, nd = points + if isinstance(nd, str): + name = nd + else: + name = nd.name + num, den = name.split('/') + rat, _ = self.get_or_create_const_rat(int(num), int(den)) + + ab = self._get_segment(a, b) + cd = self._get_segment(c, d) + + if not ab or not 
cd: + return False + + if not (ab.val and cd.val): + return False + + for rat1, _, _ in gm.all_ratios(ab._val, cd._val): + if self.is_equal(rat1, rat): + return True + return False + + def check_rcompute(self, points: list[Point]) -> bool: + """Check whether a ratio is equal to some constant.""" + a, b, c, d = points + ab = self._get_segment(a, b) + cd = self._get_segment(c, d) + + if not ab or not cd: + return False + + if not (ab.val and cd.val): + return False + + for rat0 in self.rconst.values(): + for rat in rat0.val.neighbors(Ratio): + l1, l2 = rat.lengths + if ab.val == l1 and cd.val == l2: + return True + return False + + def check_eqratio(self, points: list[Point]) -> bool: + """Check if 8 points make an eqratio predicate.""" + a, b, c, d, m, n, p, q = points + + if {a, b} == {c, d} and {m, n} == {p, q}: + return True + if {a, b} == {m, n} and {c, d} == {p, q}: + return True + + ab = self._get_segment(a, b) + cd = self._get_segment(c, d) + mn = self._get_segment(m, n) + pq = self._get_segment(p, q) + + if {a, b} == {c, d} and mn and pq and self.is_equal(mn, pq): + return True + if {a, b} == {m, n} and cd and pq and self.is_equal(cd, pq): + return True + if {p, q} == {m, n} and ab and cd and self.is_equal(ab, cd): + return True + if {p, q} == {c, d} and ab and mn and self.is_equal(ab, mn): + return True + + if not ab or not cd or not mn or not pq: + return False + + if self.is_equal(ab, cd) and self.is_equal(mn, pq): + return True + if self.is_equal(ab, mn) and self.is_equal(cd, pq): + return True + + if not (ab.val and cd.val and mn.val and pq.val): + return False + + if (ab.val, cd.val) == (mn.val, pq.val) or (ab.val, mn.val) == ( + cd.val, + pq.val, + ): + return True + + for rat1, _, _ in gm.all_ratios(ab._val, cd._val): + for rat2, _, _ in gm.all_ratios(mn._val, pq._val): + if self.is_equal(rat1, rat2): + return True + return False + + def add_simtri_check( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + if 
nm.same_clock(*[p.num for p in points]): + return self.add_simtri(points, deps) + return self.add_simtri2(points, deps) + + def add_contri_check( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + if nm.same_clock(*[p.num for p in points]): + return self.add_contri(points, deps) + return self.add_contri2(points, deps) + + def enum_sides( + self, points: list[Point] + ) -> Generator[list[Point], None, None]: + a, b, c, x, y, z = points + yield [a, b, x, y] + yield [b, c, y, z] + yield [c, a, z, x] + + def enum_triangle( + self, points: list[Point] + ) -> Generator[list[Point], None, None]: + a, b, c, x, y, z = points + yield [a, b, a, c, x, y, x, z] + yield [b, a, b, c, y, x, y, z] + yield [c, a, c, b, z, x, z, y] + + def enum_triangle2( + self, points: list[Point] + ) -> Generator[list[Point], None, None]: + a, b, c, x, y, z = points + yield [a, b, a, c, x, z, x, y] + yield [b, a, b, c, y, z, y, x] + yield [c, a, c, b, z, y, z, x] + + def add_simtri( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add two similar triangles.""" + add = [] + hashs = [d.hashed() for d in deps.why] + + for args in self.enum_triangle(points): + if problem.hashed('eqangle6', args) in hashs: + continue + add += self.add_eqangle(args, deps=deps) + + for args in self.enum_triangle(points): + if problem.hashed('eqratio6', args) in hashs: + continue + add += self.add_eqratio(args, deps=deps) + + return add + + def check_simtri(self, points: list[Point]) -> bool: + a, b, c, x, y, z = points + return self.check_eqangle([a, b, a, c, x, y, x, z]) and self.check_eqangle( + [b, a, b, c, y, x, y, z] + ) + + def add_simtri2( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add two similar reflected triangles.""" + add = [] + hashs = [d.hashed() for d in deps.why] + for args in self.enum_triangle2(points): + if problem.hashed('eqangle6', args) in hashs: + continue + add += self.add_eqangle(args, deps=deps) + + 
for args in self.enum_triangle(points): + if problem.hashed('eqratio6', args) in hashs: + continue + add += self.add_eqratio(args, deps=deps) + + return add + + def add_contri( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add two congruent triangles.""" + add = [] + hashs = [d.hashed() for d in deps.why] + for args in self.enum_triangle(points): + if problem.hashed('eqangle6', args) in hashs: + continue + add += self.add_eqangle(args, deps=deps) + + for args in self.enum_sides(points): + if problem.hashed('cong', args) in hashs: + continue + add += self.add_cong(args, deps=deps) + return add + + def check_contri(self, points: list[Point]) -> bool: + a, b, c, x, y, z = points + return ( + self.check_cong([a, b, x, y]) + and self.check_cong([b, c, y, z]) + and self.check_cong([c, a, z, x]) + ) + + def add_contri2( + self, points: list[Point], deps: EmptyDependency + ) -> list[Dependency]: + """Add two congruent reflected triangles.""" + add = [] + hashs = [d.hashed() for d in deps.why] + for args in self.enum_triangle2(points): + if problem.hashed('eqangle6', args) in hashs: + continue + add += self.add_eqangle(args, deps=deps) + + for args in self.enum_sides(points): + if problem.hashed('cong', args) in hashs: + continue + add += self.add_cong(args, deps=deps) + + return add + + def in_cache(self, name: str, args: list[Point]) -> bool: + return problem.hashed(name, args) in self.cache + + def cache_dep( + self, name: str, args: list[Point], premises: list[Dependency] + ) -> None: + hashed = problem.hashed(name, args) + if hashed in self.cache: + return + self.cache[hashed] = premises + + def all_same_line( + self, a: Point, b: Point + ) -> Generator[tuple[Point, Point], None, None]: + ab = self._get_line(a, b) + if ab is None: + return + for p1, p2 in utils.comb2(ab.neighbors(Point)): + if {p1, p2} != {a, b}: + yield p1, p2 + + def all_same_angle( + self, a: Point, b: Point, c: Point, d: Point + ) -> Generator[tuple[Point, Point, 
Point, Point], None, None]: + for x, y in self.all_same_line(a, b): + for m, n in self.all_same_line(c, d): + yield x, y, m, n + + def additionally_draw(self, name: str, args: list[Point]) -> None: + """Draw some extra line/circles for illustration purpose.""" + + if name in ['circle']: + center, point = args[:2] + circle = self.new_node(Circle, f'({center.name},{point.name})') + circle.num = nm.Circle(center.num, p1=point.num) + circle.points = center, point + + if name in ['on_circle', 'tangent']: + center, point = args[-2:] + circle = self.new_node(Circle, f'({center.name},{point.name})') + circle.num = nm.Circle(center.num, p1=point.num) + circle.points = center, point + + if name in ['incenter', 'excenter', 'incenter2', 'excenter2']: + d, a, b, c = [x for x in args[-4:]] + a, b, c = sorted([a, b, c], key=lambda x: x.name.lower()) + circle = self.new_node(Circle, f'({d.name},h.{a.name}{b.name})') + p = d.num.foot(nm.Line(a.num, b.num)) + circle.num = nm.Circle(d.num, p1=p) + circle.points = d, a, b, c + + if name in ['cc_tangent']: + o, a, w, b = args[-4:] + c1 = self.new_node(Circle, f'({o.name},{a.name})') + c1.num = nm.Circle(o.num, p1=a.num) + c1.points = o, a + + c2 = self.new_node(Circle, f'({w.name},{b.name})') + c2.num = nm.Circle(w.num, p1=b.num) + c2.points = w, b + + if name in ['ninepoints']: + a, b, c = args[-3:] + a, b, c = sorted([a, b, c], key=lambda x: x.name.lower()) + circle = self.new_node(Circle, f'(,m.{a.name}{b.name}{c.name})') + p1 = (b.num + c.num) * 0.5 + p2 = (c.num + a.num) * 0.5 + p3 = (a.num + b.num) * 0.5 + circle.num = nm.Circle(p1=p1, p2=p2, p3=p3) + circle.points = (None, None, a, b, c) + + if name in ['2l1c']: + a, b, c, o = args[:4] + a, b, c = sorted([a, b, c], key=lambda x: x.name.lower()) + circle = self.new_node(Circle, f'({o.name},{a.name}{b.name}{c.name})') + circle.num = nm.Circle(p1=a.num, p2=b.num, p3=c.num) + circle.points = (a, b, c) + + def add_clause( + self, + clause: problem.Clause, + plevel: int, + 
definitions: dict[str, problem.Definition], + verbose: int = False, + ) -> tuple[list[Dependency], int]: + """Add a new clause of construction, e.g. a new excenter.""" + existing_points = self.all_points() + new_points = [Point(name) for name in clause.points] + + new_points_dep_points = set() + new_points_dep = [] + + # Step 1: check for all deps. + for c in clause.constructions: + cdef = definitions[c.name] + + if len(cdef.construction.args) != len(c.args): + if len(cdef.construction.args) - len(c.args) == len(clause.points): + c.args = clause.points + c.args + else: + correct_form = ' '.join(cdef.points + ['=', c.name] + cdef.args) + raise ValueError('Argument mismatch. ' + correct_form) + + mapping = dict(zip(cdef.construction.args, c.args)) + c_name = 'midp' if c.name == 'midpoint' else c.name + deps = EmptyDependency(level=0, rule_name=problem.CONSTRUCTION_RULE) + deps.construction = Dependency(c_name, c.args, rule_name=None, level=0) + + for d in cdef.deps.constructions: + args = self.names2points([mapping[a] for a in d.args]) + new_points_dep_points.update(args) + if not self.check(d.name, args): + raise DepCheckFailError( + d.name + ' ' + ' '.join([x.name for x in args]) + ) + deps.why += [ + Dependency( + d.name, args, rule_name=problem.CONSTRUCTION_RULE, level=0 + ) + ] + + new_points_dep += [deps] + + # Step 2: draw. 
+ def range_fn() -> ( + list[Union[nm.Point, nm.Line, nm.Circle, nm.HalfLine, nm.HoleCircle]] + ): + to_be_intersected = [] + for c in clause.constructions: + cdef = definitions[c.name] + mapping = dict(zip(cdef.construction.args, c.args)) + for n in cdef.numerics: + args = [mapping[a] for a in n.args] + args = list(map(lambda x: self.get(x, lambda: int(x)), args)) + to_be_intersected += nm.sketch(n.name, args) + + return to_be_intersected + + is_total_free = ( + len(clause.constructions) == 1 and clause.constructions[0].name in FREE + ) + is_semi_free = ( + len(clause.constructions) == 1 + and clause.constructions[0].name in INTERSECT + ) + + existing_points = [p.num for p in existing_points] + + def draw_fn() -> list[nm.Point]: + to_be_intersected = range_fn() + return nm.reduce(to_be_intersected, existing_points) + + rely_on = set() + for c in clause.constructions: + cdef = definitions[c.name] + mapping = dict(zip(cdef.construction.args, c.args)) + for n in cdef.numerics: + args = [mapping[a] for a in n.args] + args = list(map(lambda x: self.get(x, lambda: int(x)), args)) + rely_on.update([a for a in args if isinstance(a, Point)]) + + for p in rely_on: + p.change.update(new_points) + + nums = draw_fn() + for p, num, num0 in zip(new_points, nums, clause.nums): + p.co_change = new_points + if isinstance(num0, nm.Point): + num = num0 + elif isinstance(num0, (tuple, list)): + x, y = num0 + num = nm.Point(x, y) + + p.num = num + + # check two things. + if nm.check_too_close(nums, existing_points): + raise PointTooCloseError() + if nm.check_too_far(nums, existing_points): + raise PointTooFarError() + + # Commit: now that all conditions are passed. + # add these points to current graph. + for p in new_points: + self._name2point[p.name] = p + self._name2node[p.name] = p + self.type2nodes[Point].append(p) + + for p in new_points: + p.why = sum([d.why for d in new_points_dep], []) # to generate txt logs. 
+ p.group = new_points + p.dep_points = new_points_dep_points + p.dep_points.update(new_points) + p.plevel = plevel + + # movement dependency: + rely_dict_0 = defaultdict(lambda: []) + + for c in clause.constructions: + cdef = definitions[c.name] + mapping = dict(zip(cdef.construction.args, c.args)) + for p, ps in cdef.rely.items(): + p = mapping[p] + ps = [mapping[x] for x in ps] + rely_dict_0[p].append(ps) + + rely_dict = {} + for p, pss in rely_dict_0.items(): + ps = sum(pss, []) + if len(pss) > 1: + ps = [x for x in ps if x != p] + + p = self._name2point[p] + ps = self.names2nodes(ps) + rely_dict[p] = ps + + for p in new_points: + p.rely_on = set(rely_dict.get(p, [])) + for x in p.rely_on: + if not hasattr(x, 'base_rely_on'): + x.base_rely_on = set() + p.base_rely_on = set.union(*[x.base_rely_on for x in p.rely_on] + [set()]) + if is_total_free or is_semi_free: + p.rely_on.add(p) + p.base_rely_on.add(p) + + plevel_done = set() + added = [] + basics = [] + # Step 3: build the basics. + for c, deps in zip(clause.constructions, new_points_dep): + cdef = definitions[c.name] + mapping = dict(zip(cdef.construction.args, c.args)) + + # not necessary for proving, but for visualization. 
+ c_args = list(map(lambda x: self.get(x, lambda: int(x)), c.args)) + self.additionally_draw(c.name, c_args) + + for points, bs in cdef.basics: + if points: + points = self.names2nodes([mapping[p] for p in points]) + points = [p for p in points if p not in plevel_done] + for p in points: + p.plevel = plevel + plevel_done.update(points) + plevel += 1 + else: + continue + + for b in bs: + if b.name != 'rconst': + args = [mapping[a] for a in b.args] + else: + num, den = map(int, b.args[-2:]) + rat, _ = self.get_or_create_const_rat(num, den) + args = [mapping[a] for a in b.args[:-2]] + [rat.name] + + args = list(map(lambda x: self.get(x, lambda: int(x)), args)) + + adds = self.add_piece(name=b.name, args=args, deps=deps) + basics.append((b.name, args, deps)) + if adds: + added += adds + for add in adds: + self.cache_dep(add.name, add.args, add) + + assert len(plevel_done) == len(new_points) + for p in new_points: + p.basics = basics + + return added, plevel + + def all_eqangle_same_lines(self) -> Generator[tuple[Point, ...], None, None]: + for l1, l2 in utils.perm2(self.type2nodes[Line]): + for a, b, c, d, e, f, g, h in utils.all_8points(l1, l2, l1, l2): + if (a, b, c, d) != (e, f, g, h): + yield a, b, c, d, e, f, g, h + + def all_eqangles_distinct_linepairss( + self, + ) -> Generator[tuple[Line, ...], None, None]: + """No eqangles because para-para, or para-corresponding, or same.""" + + for measure in self.type2nodes[Measure]: + angs = measure.neighbors(Angle) + line_pairss = [] + for ang in angs: + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + l1s = d1.neighbors(Line) + l2s = d2.neighbors(Line) + # Any pair in this is para-para. 
+ para_para = list(utils.cross(l1s, l2s)) + line_pairss.append(para_para) + + for pairs1, pairs2 in utils.comb2(line_pairss): + for pair1, pair2 in utils.cross(pairs1, pairs2): + (l1, l2), (l3, l4) = pair1, pair2 + yield l1, l2, l3, l4 + + def all_eqangles_8points(self) -> Generator[tuple[Point, ...], None, None]: + """List all sets of 8 points that make two equal angles.""" + # Case 1: (l1-l2) = (l3-l4), including because l1//l3, l2//l4 (para-para) + angss = [] + for measure in self.type2nodes[Measure]: + angs = measure.neighbors(Angle) + angss.append(angs) + + # include the angs that do not have any measure. + angss.extend([[ang] for ang in self.type2nodes[Angle] if ang.val is None]) + + line_pairss = [] + for angs in angss: + line_pairs = set() + for ang in angs: + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + l1s = d1.neighbors(Line) + l2s = d2.neighbors(Line) + line_pairs.update(set(utils.cross(l1s, l2s))) + line_pairss.append(line_pairs) + + # include (d1, d2) in which d1 does not have any angles. + noang_ds = [d for d in self.type2nodes[Direction] if not d.neighbors(Angle)] + + for d1 in noang_ds: + for d2 in self.type2nodes[Direction]: + if d1 == d2: + continue + l1s = d1.neighbors(Line) + l2s = d2.neighbors(Line) + if len(l1s) < 2 and len(l2s) < 2: + continue + line_pairss.append(set(utils.cross(l1s, l2s))) + line_pairss.append(set(utils.cross(l2s, l1s))) + + # Case 2: d1 // d2 => (d1-d3) = (d2-d3) + # include lines that do not have any direction. 
+ nodir_ls = [l for l in self.type2nodes[Line] if l.val is None] + + for line in nodir_ls: + for d in self.type2nodes[Direction]: + l1s = d.neighbors(Line) + if len(l1s) < 2: + continue + l2s = [line] + line_pairss.append(set(utils.cross(l1s, l2s))) + line_pairss.append(set(utils.cross(l2s, l1s))) + + record = set() + for line_pairs in line_pairss: + for pair1, pair2 in utils.perm2(list(line_pairs)): + (l1, l2), (l3, l4) = pair1, pair2 + if l1 == l2 or l3 == l4: + continue + if (l1, l2) == (l3, l4): + continue + if (l1, l2, l3, l4) in record: + continue + record.add((l1, l2, l3, l4)) + for a, b, c, d, e, f, g, h in utils.all_8points(l1, l2, l3, l4): + yield (a, b, c, d, e, f, g, h) + + for a, b, c, d, e, f, g, h in self.all_eqangle_same_lines(): + yield a, b, c, d, e, f, g, h + + def all_eqangles_6points(self) -> Generator[tuple[Point, ...], None, None]: + """List all sets of 6 points that make two equal angles.""" + record = set() + for a, b, c, d, e, f, g, h in self.all_eqangles_8points(): + if ( + a not in (c, d) + and b not in (c, d) + or e not in (g, h) + and f not in (g, h) + ): + continue + + if b in (c, d): + a, b = b, a # now a in c, d + if f in (g, h): + e, f = f, e # now e in g, h + if a == d: + c, d = d, c # now a == c + if e == h: + g, h = h, g # now e == g + if (a, b, c, d, e, f, g, h) in record: + continue + record.add((a, b, c, d, e, f, g, h)) + yield a, b, c, d, e, f, g, h # where a==c, e==g + + def all_paras(self) -> Generator[tuple[Point, ...], None, None]: + for d in self.type2nodes[Direction]: + for l1, l2 in utils.perm2(d.neighbors(Line)): + for a, b, c, d in utils.all_4points(l1, l2): + yield a, b, c, d + + def all_perps(self) -> Generator[tuple[Point, ...], None, None]: + for ang in self.vhalfpi.neighbors(Angle): + d1, d2 = ang.directions + if d1 is None or d2 is None: + continue + if d1 == d2: + continue + for l1, l2 in utils.cross(d1.neighbors(Line), d2.neighbors(Line)): + for a, b, c, d in utils.all_4points(l1, l2): + yield a, b, c, d + + 
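+ # The generators above (all_paras, all_perps) share one enumeration
+ # pattern via utils.all_4points: for a pair of lines, every ordered
+ # choice of two distinct points on each line yields one 4-point
+ # predicate. A rough standalone sketch of that pattern (the helper
+ # below is a hypothetical stand-in, not the repo's utils
+ # implementation):
+ #
+ #   from itertools import permutations
+ #
+ #   def all_4points(l1_points, l2_points):
+ #       # Ordered pairs of distinct points on the first line, crossed
+ #       # with ordered pairs of distinct points on the second line.
+ #       for a, b in permutations(l1_points, 2):
+ #           for c, d in permutations(l2_points, 2):
+ #               yield a, b, c, d
+ #
+ #   # Two lines sharing a Direction node (i.e. parallel) with 2 and 3
+ #   # points on them produce 2 * 6 = 12 ordered `para` tuples, which
+ #   # is why these generators are quadratic in points per line:
+ #   statements = list(all_4points(['a', 'b'], ['c', 'd', 'e']))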
def all_congs(self) -> Generator[tuple[Point, ...], None, None]: + for l in self.type2nodes[Length]: + for s1, s2 in utils.perm2(l.neighbors(Segment)): + (a, b), (c, d) = s1.points, s2.points + for x, y in [(a, b), (b, a)]: + for m, n in [(c, d), (d, c)]: + yield x, y, m, n + + def all_eqratios_8points(self) -> Generator[tuple[Point, ...], None, None]: + """List all sets of 8 points that make two equal ratios.""" + ratss = [] + for value in self.type2nodes[Value]: + rats = value.neighbors(Ratio) + ratss.append(rats) + + # include the rats that do not have any val. + ratss.extend([[rat] for rat in self.type2nodes[Ratio] if rat.val is None]) + + seg_pairss = [] + for rats in ratss: + seg_pairs = set() + for rat in rats: + l1, l2 = rat.lengths + if l1 is None or l2 is None: + continue + s1s = l1.neighbors(Segment) + s2s = l2.neighbors(Segment) + seg_pairs.update(utils.cross(s1s, s2s)) + seg_pairss.append(seg_pairs) + + # include (l1, l2) in which l1 does not have any ratio. + norat_ls = [l for l in self.type2nodes[Length] if not l.neighbors(Ratio)] + + for l1 in norat_ls: + for l2 in self.type2nodes[Length]: + if l1 == l2: + continue + s1s = l1.neighbors(Segment) + s2s = l2.neighbors(Segment) + if len(s1s) < 2 and len(s2s) < 2: + continue + seg_pairss.append(set(utils.cross(s1s, s2s))) + seg_pairss.append(set(utils.cross(s2s, s1s))) + + # include Segments that do not have any Length. 
+ nolen_ss = [s for s in self.type2nodes[Segment] if s.val is None] + + for seg in nolen_ss: + for l in self.type2nodes[Length]: + s1s = l.neighbors(Segment) + if len(s1s) == 1: + continue + s2s = [seg] + seg_pairss.append(set(utils.cross(s1s, s2s))) + seg_pairss.append(set(utils.cross(s2s, s1s))) + + record = set() + for seg_pairs in seg_pairss: + for pair1, pair2 in utils.perm2(list(seg_pairs)): + (s1, s2), (s3, s4) = pair1, pair2 + if s1 == s2 or s3 == s4: + continue + if (s1, s2) == (s3, s4): + continue + if (s1, s2, s3, s4) in record: + continue + record.add((s1, s2, s3, s4)) + a, b = s1.points + c, d = s2.points + e, f = s3.points + g, h = s4.points + + for x, y in [(a, b), (b, a)]: + for z, t in [(c, d), (d, c)]: + for m, n in [(e, f), (f, e)]: + for p, q in [(g, h), (h, g)]: + yield (x, y, z, t, m, n, p, q) + + segss = [] + # finally the list of ratios that are equal to 1.0 + for length in self.type2nodes[Length]: + segs = length.neighbors(Segment) + segss.append(segs) + + segs_pair = list(utils.perm2(list(segss))) + segs_pair += list(zip(segss, segss)) + for segs1, segs2 in segs_pair: + for s1, s2 in utils.perm2(list(segs1)): + for s3, s4 in utils.perm2(list(segs2)): + if (s1, s2) == (s3, s4) or (s1, s3) == (s2, s4): + continue + if (s1, s2, s3, s4) in record: + continue + record.add((s1, s2, s3, s4)) + a, b = s1.points + c, d = s2.points + e, f = s3.points + g, h = s4.points + + for x, y in [(a, b), (b, a)]: + for z, t in [(c, d), (d, c)]: + for m, n in [(e, f), (f, e)]: + for p, q in [(g, h), (h, g)]: + yield (x, y, z, t, m, n, p, q) + + def all_eqratios_6points(self) -> Generator[tuple[Point, ...], None, None]: + """List all sets of 6 points that make two equal ratios.""" + record = set() + for a, b, c, d, e, f, g, h in self.all_eqratios_8points(): + if ( + a not in (c, d) + and b not in (c, d) + or e not in (g, h) + and f not in (g, h) + ): + continue + if b in (c, d): + a, b = b, a + if f in (g, h): + e, f = f, e + if a == d: + c, d = d, c + if e == h: 
+ g, h = h, g + if (a, b, c, d, e, f, g, h) in record: + continue + record.add((a, b, c, d, e, f, g, h)) + yield a, b, c, d, e, f, g, h # now a==c, e==g + + def all_cyclics(self) -> Generator[tuple[Point, ...], None, None]: + for c in self.type2nodes[Circle]: + for x, y, z, t in utils.perm4(c.neighbors(Point)): + yield x, y, z, t + + def all_colls(self) -> Generator[tuple[Point, ...], None, None]: + for l in self.type2nodes[Line]: + for x, y, z in utils.perm3(l.neighbors(Point)): + yield x, y, z + + def all_midps(self) -> Generator[tuple[Point, ...], None, None]: + for l in self.type2nodes[Line]: + for a, b, c in utils.perm3(l.neighbors(Point)): + if self.check_cong([a, b, a, c]): + yield a, b, c + + def all_circles(self) -> Generator[tuple[Point, ...], None, None]: + for l in self.type2nodes[Length]: + p2p = defaultdict(list) + for s in l.neighbors(Segment): + a, b = s.points + p2p[a].append(b) + p2p[b].append(a) + for p, ps in p2p.items(): + if len(ps) >= 3: + for a, b, c in utils.perm3(ps): + yield p, a, b, c + + def two_points_on_direction(self, d: Direction) -> tuple[Point, Point]: + l = d.neighbors(Line)[0] + p1, p2 = l.neighbors(Point)[:2] + return p1, p2 + + def two_points_of_length(self, l: Length) -> tuple[Point, Point]: + s = l.neighbors(Segment)[0] + p1, p2 = s.points + return p1, p2 + + +def create_consts_str(g: Graph, s: str) -> Union[Ratio, Angle]: + if 'pi/' in s: + n, d = s.split('pi/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_ang(n, d) + else: + n, d = s.split('/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_rat(n, d) + return p0 + + +def create_consts(g: Graph, p: gm.Node) -> Union[Ratio, Angle]: + if isinstance(p, Angle): + n, d = p.name.split('pi/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_ang(n, d) + elif isinstance(p, Ratio): + n, d = p.name.split('/') + n, d = int(n), int(d) + p0, _ = g.get_or_create_const_rat(n, d) + else: + raise ValueError(f'Cannot create const for {p}') + return p0 diff --git a/graph_test.py
b/graph_test.py new file mode 100644 index 0000000..ea7e213 --- /dev/null +++ b/graph_test.py @@ -0,0 +1,164 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for graph.py.""" +import unittest + +from absl.testing import absltest +import graph as gh +import numericals as nm +import problem as pr + + +MAX_LEVEL = 1000 + + +class GraphTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) + + # load a complex setup: + txt = 'a b c = triangle a b c; h = orthocenter a b c; h1 = foot a b c; h2 = foot b c a; h3 = foot c a b; g1 g2 g3 g = centroid g1 g2 g3 g a b c; o = circle a b c ? 
coll h g o' # pylint: disable=line-too-long + p = pr.Problem.from_txt(txt, translate=False) + cls.g, _ = gh.Graph.build_problem(p, GraphTest.defs) + + def test_build_graph_points(self): + g = GraphTest.g + + all_points = g.all_points() + all_names = [p.name for p in all_points] + self.assertCountEqual( + all_names, + ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'], + ) + + def test_build_graph_predicates(self): + gr = GraphTest.g + + a, b, c, g, h, o, g1, g2, g3, h1, h2, h3 = gr.names2points( + ['a', 'b', 'c', 'g', 'h', 'o', 'g1', 'g2', 'g3', 'h1', 'h2', 'h3'] + ) + + # Explicit statements: + self.assertTrue(gr.check_cong([b, g1, g1, c])) + self.assertTrue(gr.check_cong([c, g2, g2, a])) + self.assertTrue(gr.check_cong([a, g3, g3, b])) + self.assertTrue(gr.check_perp([a, h1, b, c])) + self.assertTrue(gr.check_perp([b, h2, c, a])) + self.assertTrue(gr.check_perp([c, h3, a, b])) + self.assertTrue(gr.check_cong([o, a, o, b])) + self.assertTrue(gr.check_cong([o, b, o, c])) + self.assertTrue(gr.check_cong([o, a, o, c])) + self.assertTrue(gr.check_coll([a, g, g1])) + self.assertTrue(gr.check_coll([b, g, g2])) + self.assertTrue(gr.check_coll([g1, b, c])) + self.assertTrue(gr.check_coll([g2, c, a])) + self.assertTrue(gr.check_coll([g3, a, b])) + self.assertTrue(gr.check_perp([a, h, b, c])) + self.assertTrue(gr.check_perp([b, h, c, a])) + + # These are NOT part of the premises: + self.assertFalse(gr.check_perp([c, h, a, b])) + self.assertFalse(gr.check_coll([c, g, g3])) + + # These are automatically inferred by the graph datastructure: + self.assertTrue(gr.check_eqangle([a, h1, b, c, b, h2, c, a])) + self.assertTrue(gr.check_eqangle([a, h1, b, h2, b, c, c, a])) + self.assertTrue(gr.check_eqratio([b, g1, g1, c, c, g2, g2, a])) + self.assertTrue(gr.check_eqratio([b, g1, g1, c, o, a, o, b])) + self.assertTrue(gr.check_para([a, h, a, h1])) + self.assertTrue(gr.check_para([b, h, b, h2])) + self.assertTrue(gr.check_coll([a, h, h1])) + 
self.assertTrue(gr.check_coll([b, h, h2])) + + def test_enumerate_colls(self): + g = GraphTest.g + + for a, b, c in g.all_colls(): + self.assertTrue(g.check_coll([a, b, c])) + self.assertTrue(nm.check_coll([a.num, b.num, c.num])) + + def test_enumerate_paras(self): + g = GraphTest.g + + for a, b, c, d in g.all_paras(): + self.assertTrue(g.check_para([a, b, c, d])) + self.assertTrue(nm.check_para([a.num, b.num, c.num, d.num])) + + def test_enumerate_perps(self): + g = GraphTest.g + + for a, b, c, d in g.all_perps(): + self.assertTrue(g.check_perp([a, b, c, d])) + self.assertTrue(nm.check_perp([a.num, b.num, c.num, d.num])) + + def test_enumerate_congs(self): + g = GraphTest.g + + for a, b, c, d in g.all_congs(): + self.assertTrue(g.check_cong([a, b, c, d])) + self.assertTrue(nm.check_cong([a.num, b.num, c.num, d.num])) + + def test_enumerate_eqangles(self): + g = GraphTest.g + + for a, b, c, d, x, y, z, t in g.all_eqangles_8points(): + self.assertTrue(g.check_eqangle([a, b, c, d, x, y, z, t])) + self.assertTrue( + nm.check_eqangle( + [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num] + ) + ) + + def test_enumerate_eqratios(self): + g = GraphTest.g + + for a, b, c, d, x, y, z, t in g.all_eqratios_8points(): + self.assertTrue(g.check_eqratio([a, b, c, d, x, y, z, t])) + self.assertTrue( + nm.check_eqratio( + [a.num, b.num, c.num, d.num, x.num, y.num, z.num, t.num] + ) + ) + + def test_enumerate_cyclics(self): + g = GraphTest.g + + for a, b, c, d in g.all_cyclics(): + self.assertTrue(g.check_cyclic([a, b, c, d])) + self.assertTrue(nm.check_cyclic([a.num, b.num, c.num, d.num])) + + def test_enumerate_midps(self): + g = GraphTest.g + + for a, b, c in g.all_midps(): + self.assertTrue(g.check_midp([a, b, c])) + self.assertTrue(nm.check_midp([a.num, b.num, c.num])) + + def test_enumerate_circles(self): + g = GraphTest.g + + for a, b, c, d in g.all_circles(): + self.assertTrue(g.check_circle([a, b, c, d])) +
self.assertTrue(nm.check_circle([a.num, b.num, c.num, d.num])) + + +if __name__ == '__main__': + absltest.main() diff --git a/graph_utils.py b/graph_utils.py new file mode 100644 index 0000000..f53214b --- /dev/null +++ b/graph_utils.py @@ -0,0 +1,132 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities for graph representation. + +Mainly for listing combinations and permutations of elements.
+""" + +from geometry import Point + + +def _cross(elems1, elems2): + for e1 in elems1: + for e2 in elems2: + yield e1, e2 + + +def cross(elems1, elems2): + return list(_cross(elems1, elems2)) + + +def _comb2(elems): + if len(elems) < 2: + return + for i, e1 in enumerate(elems[:-1]): + for e2 in elems[i + 1 :]: + yield e1, e2 + + +def comb2(elems): + return list(_comb2(elems)) + + +def _comb3(elems): + if len(elems) < 3: + return + for i, e1 in enumerate(elems[:-2]): + for j, e2 in enumerate(elems[i + 1 : -1]): + for e3 in elems[i + j + 2 :]: + yield e1, e2, e3 + + +def comb3(elems): + return list(_comb3(elems)) + + +def _comb4(elems): + if len(elems) < 4: + return + for i, e1 in enumerate(elems[:-3]): + for j, e2 in enumerate(elems[i + 1 : -2]): + for e3, e4 in _comb2(elems[i + j + 2 :]): + yield e1, e2, e3, e4 + + +def comb4(elems): + return list(_comb4(elems)) + + +def _perm2(elems): + for e1, e2 in comb2(elems): + yield e1, e2 + yield e2, e1 + + +def perm2(elems): + return list(_perm2(elems)) + + +def _all_4points(l1, l2): + p1s = l1.neighbors(Point) + p2s = l2.neighbors(Point) + for a, b in perm2(p1s): + for c, d in perm2(p2s): + yield a, b, c, d + + +def all_4points(l1, l2): + return list(_all_4points(l1, l2)) + + +def _all_8points(l1, l2, l3, l4): + for a, b, c, d in all_4points(l1, l2): + for e, f, g, h in all_4points(l3, l4): + yield (a, b, c, d, e, f, g, h) + + +def all_8points(l1, l2, l3, l4): + return list(_all_8points(l1, l2, l3, l4)) + + +def _perm3(elems): + for x in elems: + for y in elems: + if y == x: + continue + for z in elems: + if z not in (x, y): + yield x, y, z + + +def perm3(elems): + return list(_perm3(elems)) + + +def _perm4(elems): + for x in elems: + for y in elems: + if y == x: + continue + for z in elems: + if z in (x, y): + continue + for t in elems: + if t not in (x, y, z): + yield x, y, z, t + + +def perm4(elems): + return list(_perm4(elems)) diff --git a/graph_utils_test.py b/graph_utils_test.py new file mode 100644 index 
0000000..4e77238 --- /dev/null +++ b/graph_utils_test.py @@ -0,0 +1,145 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit tests for graph_utils.py.""" +import unittest + +from absl.testing import absltest +import graph_utils as gu + + +class GraphUtilsTest(unittest.TestCase): + + def test_cross(self): + self.assertEqual(gu.cross([], [1]), []) + self.assertEqual(gu.cross([1], []), []) + self.assertEqual(gu.cross([1], [2]), [(1, 2)]) + self.assertEqual(gu.cross([1], [2, 3]), [(1, 2), (1, 3)]) + + e1 = [1, 2, 3] + e2 = [4, 5] + target = [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)] + self.assertEqual(gu.cross(e1, e2), target) + + def test_comb2(self): + self.assertEqual(gu.comb2([]), []) + self.assertEqual(gu.comb2([1]), []) + self.assertEqual(gu.comb2([1, 2]), [(1, 2)]) + self.assertEqual(gu.comb2([1, 2, 3]), [(1, 2), (1, 3), (2, 3)]) + + def test_comb3(self): + self.assertEqual(gu.comb3([]), []) + self.assertEqual(gu.comb3([1]), []) + self.assertEqual(gu.comb3([1, 2]), []) + self.assertEqual(gu.comb3([1, 2, 3]), [(1, 2, 3)]) + self.assertEqual( + gu.comb3([1, 2, 3, 4]), [(1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4)] + ) + + def test_comb4(self): + self.assertEqual(gu.comb4([]), []) + self.assertEqual(gu.comb4([1]), []) + self.assertEqual(gu.comb4([1, 2]), []) + self.assertEqual(gu.comb4([1, 2, 3]), []) + 
self.assertEqual(gu.comb4([1, 2, 3, 4]), [(1, 2, 3, 4)]) + self.assertEqual( + gu.comb4([1, 2, 3, 4, 5]), + [(1, 2, 3, 4), (1, 2, 3, 5), (1, 2, 4, 5), (1, 3, 4, 5), (2, 3, 4, 5)], + ) + + def test_perm2(self): + self.assertEqual(gu.perm2([]), []) + self.assertEqual(gu.perm2([1]), []) + self.assertEqual(gu.perm2([1, 2]), [(1, 2), (2, 1)]) + self.assertEqual( + gu.perm2([1, 2, 3]), [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3), (3, 2)] + ) + + def test_perm3(self): + self.assertEqual(gu.perm3([]), []) + self.assertEqual(gu.perm3([1]), []) + self.assertEqual(gu.perm3([1, 2]), []) + self.assertEqual( + gu.perm3([1, 2, 3]), + [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)], + ) + self.assertEqual( + gu.perm3([1, 2, 3, 4]), + [ + (1, 2, 3), + (1, 2, 4), + (1, 3, 2), + (1, 3, 4), + (1, 4, 2), + (1, 4, 3), + (2, 1, 3), + (2, 1, 4), + (2, 3, 1), + (2, 3, 4), + (2, 4, 1), + (2, 4, 3), + (3, 1, 2), + (3, 1, 4), + (3, 2, 1), + (3, 2, 4), + (3, 4, 1), + (3, 4, 2), + (4, 1, 2), + (4, 1, 3), + (4, 2, 1), + (4, 2, 3), + (4, 3, 1), + (4, 3, 2), + ], + ) + + def test_perm4(self): + self.assertEqual(gu.perm4([]), []) + self.assertEqual(gu.perm4([1]), []) + self.assertEqual(gu.perm4([1, 2]), []) + self.assertEqual(gu.perm4([1, 2, 3]), []) + self.assertEqual( + gu.perm4([1, 2, 3, 4]), + [ + (1, 2, 3, 4), + (1, 2, 4, 3), + (1, 3, 2, 4), + (1, 3, 4, 2), + (1, 4, 2, 3), + (1, 4, 3, 2), + (2, 1, 3, 4), + (2, 1, 4, 3), + (2, 3, 1, 4), + (2, 3, 4, 1), + (2, 4, 1, 3), + (2, 4, 3, 1), + (3, 1, 2, 4), + (3, 1, 4, 2), + (3, 2, 1, 4), + (3, 2, 4, 1), + (3, 4, 1, 2), + (3, 4, 2, 1), + (4, 1, 2, 3), + (4, 1, 3, 2), + (4, 2, 1, 3), + (4, 2, 3, 1), + (4, 3, 1, 2), + (4, 3, 2, 1), + ], + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/imo_ag_30.txt b/imo_ag_30.txt new file mode 100644 index 0000000..cec40d6 --- /dev/null +++ b/imo_ag_30.txt @@
-0,0 +1,60 @@ +translated_imo_2000_p1 +a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q +translated_imo_2000_p6 +a b c = triangle a b c; h = orthocenter h a b c; t1 t2 t3 i = incenter2 t1 t2 t3 i a b c; h1 = foot h1 a b c; h2 = foot h2 b c a; h3 = foot h3 c a b; x1 = reflect x1 h1 t1 t2; x2 = reflect x2 h2 t1 t2; y2 = reflect y2 h2 t2 t3; y3 = reflect y3 h3 t2 t3; z = on_line z x1 x2, on_line z y2 y3 ? cong i z i t1 +translated_imo_2002_p2a +b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle e c e j e j e f +translated_imo_2002_p2b +b c = segment b c; o = midpoint o b c; a = on_circle a o b; d = on_circle d o b, on_bline d a b; e = on_bline e o a, on_circle e o b; f = on_bline f o a, on_circle f o b; j = on_pline j o a d, on_line j a c ? eqangle c e c j c j c f +translated_imo_2003_p4 +a b c = triangle a b c; o = circle o a b c; b1 = on_circle b1 o a, on_bline b1 c a; d1 = on_circle d1 o a, on_bline d1 c a; x = on_line x b b1, on_line x a c; d = on_line d d1 x, on_circle d o a; p = foot p d b c; q = foot q d c a; r = foot r d a b ? cong p q q r +translated_imo_2004_p1 +a b c = triangle a b c; o = midpoint o b c; m = on_circle m o b, on_line m a b; n = on_circle n o b, on_line n a c; r = angle_bisector r b a c, angle_bisector r m o n; o1 = circle o1 b m r; o2 = circle o2 c n r; p = on_circle p o1 r, on_circle p o2 r ? coll p b c +translated_imo_2004_p5 +a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; p = on_aline p b c a b d, on_aline p d c a d b ? 
cong a p c p +translated_imo_2005_p5 +a b c = triangle a b c; d = eqdistance d a b c; e = on_line e b c; f = on_line f a d, eqdistance f d e b; p = on_line p a c, on_line p b d; q = on_line q e f, on_line q b d; r = on_line r e f, on_line r a c; o1 = circle o1 a p d; o2 = circle o2 b p c; m = on_circle m o1 p, on_circle m o2 p ? cyclic p q r m +translated_imo_2007_p4 +a b c = triangle a b c; o = circle o a b c; r = on_circle r o a, on_bline r a b; l = midpoint l c a; k = midpoint k c b; p = on_line p o k, on_line p c r; q = on_line q o l, on_line q c r; l1 = foot l1 l c r; k1 = foot k1 k c r ? eqratio k k1 l l1 r q r p +translated_imo_2008_p1a +a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? cyclic c1 c2 b1 b2 +translated_imo_2008_p1b +a b c = triangle a b c; h = orthocenter h a b c; d = midpoint d b c; e = midpoint e a c; f = midpoint f a b; a1 = on_circle a1 d h, on_line a1 b c; a2 = on_circle a2 d h, on_line a2 b c; b1 = on_circle b1 e h, on_line b1 c a; b2 = on_circle b2 e h, on_line b2 c a; c1 = on_circle c1 f h, on_line c1 a b; c2 = on_circle c2 f h, on_line c2 a b ? cyclic c1 c2 b1 a1 +translated_imo_2008_p6 +x@4.96_-0.13 y@-1.0068968328888160_-1.2534881080682770 z@-2.8402847238575120_-4.9117762734006830 = triangle x y z; o = circle o x y z; w@6.9090049230038776_-1.3884003936987552 = on_circle w o x; a = on_tline a z o z, on_tline a x o x; b = on_tline b z o z, on_tline b w o w; c = on_tline c y o y, on_tline c w o w; d = on_tline d x o x, on_tline d y o y; i1 = incenter i1 a b c; i2 = incenter i2 a c d; f1 = foot f1 i1 a c; f2 = foot f2 i2 a c; q t p s = cc_tangent q t p s i1 f1 i2 f2; k = on_line k q t, on_line k p s ? 
cong o k o x +translated_imo_2009_p2 +m l k = triangle m l k; w = circle w m l k; q = on_tline q m w m; p = mirror p q m; b = mirror b p k; c = mirror c q l; a = on_line a b q, on_line a c p; o = circle o a b c ? cong o p o q +translated_imo_2010_p2 +a b c = triangle a b c; o = circle o a b c; i = incenter i a b c; d = on_line d a i, on_circle d o a; f = on_line f b c; e = on_aline e a c b a f, on_circle e o a; g = midpoint g i f; k = on_line k d g, on_line k e i ? cong o a o k +translated_imo_2010_p4 +s c p = iso_triangle s c p; o = on_tline o c s c; a = on_circle a o c; b = on_circle b o c, on_line b s a; m = on_line m c p, on_circle m o c; l = on_line l b p, on_circle l o c; k = on_line k a p, on_circle k o c ? cong m k m l +translated_imo_2011_p6 +a b c = triangle a b c; o = circle o a b c; p = on_circle p o a; q = on_tline q p o p; pa = reflect pa p b c; pb = reflect pb p c a; pc = reflect pc p a b; qa = reflect qa q b c; qb = reflect qb q c a; qc = reflect qc q a b; a1 = on_line a1 pb qb, on_line a1 pc qc; b1 = on_line b1 pa qa, on_line b1 pc qc; c1 = on_line c1 pa qa, on_line c1 pb qb; o1 = circle o1 a1 b1 c1; x = on_circle x o a, on_circle x o1 a1 ? coll x o o1 +translated_imo_2012_p1 +a b c = triangle a b c; m l k j = excenter2 m l k j a b c; f = on_line f m l, on_line f b j; g = on_line g m k, on_line g c j; s = on_line s f a, on_line s b c; t = on_line t g a, on_line t c b ? cong m s m t +translated_imo_2012_p5 +c a b = r_triangle c a b; d = foot d c a b; x = on_line x c d; k = on_line k a x, on_circle k b c; l = on_line l b x, on_circle l a c; m = on_line m a l, on_line m b k ? cong m k m l +translated_imo_2013_p4 +a b c = triangle a b c; h = orthocenter h a b c; m = on_line m h b, on_line m a c; n = on_line n h c, on_line n a b; w = on_line w b c; o1 = circle o1 b n w; o2 = circle o2 c m w; x = on_line x o1 w, on_circle x o1 w; y = on_line y o2 w, on_circle y o2 w ? 
coll x h y +translated_imo_2014_p4 +a b c = triangle a b c; p = on_line p b c, on_aline p a b b c a; q = on_line q b c, on_aline q a c c b a; m = mirror m a p; n = mirror n a q; x = on_line x b m, on_line x c n; o = circle o a b c ? cong o x o a +translated_imo_2015_p3 +a b c = triangle a b c; h = orthocenter h a b c; f = on_line f h a, on_line f b c; m = midpoint m b c; o = circle o a b c; q = on_dia q a h, on_circle q o a; k = on_dia k h q, on_circle k o a; o1 = circle o1 k q h; o2 = circle o2 f k m ? coll o1 o2 k +translated_imo_2015_p4 +a b c = triangle a b c; o = circle o a b c; d = on_line d b c; e = on_line e b c, on_circle e a d; f = on_circle f o a, on_circle f a d; g = on_circle g o a, on_circle g a d; o1 = circle o1 f b d; o2 = circle o2 g c e; k = on_circle k o1 b, on_line k a b; l = on_circle l o2 c, on_line l a c; x = on_line x f k, on_line x l g ? coll x o a +translated_imo_2016_p1 +a b z = triangle a b z; f = angle_bisector f b a z, on_bline f a b; c = on_tline c b f b, on_line c a f; d = on_line d a z, on_bline d a c; e = angle_mirror e c a d, on_bline e a d; m = midpoint m c f; x = parallelogram e a m x; y = on_line y f x, on_line y e m ? coll y b d +translated_imo_2017_p4 +r s = segment r s; t = mirror t r s; o = on_bline o r s; j = on_circle j o s; o1 = circle o1 j s t; a = on_tline a r o r, on_circle a o1 s; b = on_tline b r o r, on_circle b o1 s; k = on_line k j a, on_circle k o s ? perp k t o1 t +translated_imo_2018_p1 +a b c = triangle a b c; o = circle o a b c; d = on_line d a b; e = on_line e a c, on_circle e a d; f = on_bline f b d, on_circle f o a; g = on_bline g e c, on_circle g o a ? para d e f g +translated_imo_2019_p2 +a b c = triangle; a1 = on_line b c; b1 = on_line a c; p = on_line a a1; q = on_line b b1, on_pline p a b; p1 = on_line p b1, eqangle3 p c a b c; q1 = on_line q a1, eqangle3 c q b c a ? 
cyclic p q p1 q1 +translated_imo_2019_p6 +a b c = triangle a b c; d e f i = incenter2 d e f i a b c; r = on_tline r d e f, on_circle r i d; p = on_line p r a, on_circle p i d; o1 = circle o1 p c e; o2 = circle o2 p b f; q = on_circle q o1 p, on_circle q o2 p; t = on_line t p q, on_line t i d ? perp a t a i +translated_imo_2020_p1 +p a b = triangle p a b; x = angle_bisector p b a; y = angle_bisector p a b; z = on_aline z a p a b x; t = on_aline t p a p a z; d = on_aline d p t p b a, on_line a z; u = on_aline u b p b a y; v = on_aline v p b p b u; c = on_aline c p v p a b, on_line b u; o = angle_bisector a d p, angle_bisector p c b ? cong o a o b +translated_imo_2021_p3 +a b c = triangle; d = angle_bisector b a c; e = on_aline d a d c b, on_line a c; f = on_aline d a d b c, on_line a b; x = on_bline b c, on_line a c; o1 = circle a d c; o2 = circle e x d; y = on_line e f, on_line b c ? coll o1 o2 y +translated_imo_2022_p4 +b c = segment; d = free; e = eqdistance d b c; t = on_bline b d, on_bline c e; a = eqangle2 b t e; p = on_line a b, on_line c d; q = on_line a b, on_line c t; r = on_line a e, on_line c d; s = on_line a e, on_line d t ? cyclic p q r s diff --git a/jgex_ag_231.txt b/jgex_ag_231.txt new file mode 100644 index 0000000..4693316 --- /dev/null +++ b/jgex_ag_231.txt @@ -0,0 +1,462 @@ +examples/complete2/012/complete_004_6_GDD_FULL_81-109_101.gex +a b c = triangle a b c; o = circle o a b c; h = midpoint h c b; d = on_line d o h, on_line d a b; e = on_tline e c c o, on_tline e a a o ? cyclic a o e d +examples/complete2/012/complete_002_6_GDD_FULL_41-60_59.gex +a b c = triangle a b c; m = midpoint m b a; o = circle o a b c; n = on_line n o m, on_circle n o a ? eqangle c a c n c n c b +examples/complete2/012/complete_002_6_GDD_FULL_01-20_04.gex +a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; q = midpoint q c b; s = midpoint s a d; j = midpoint j s q; m = mirror m o j; i = on_line i a d, on_line i b c ? 
perp s m b c +examples/complete2/012/complete_004_6_GDD_FULL_81-109_90.gex +a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; g = foot g d a b; f = foot f d a c; c1 = on_circle c1 o d, on_line c1 d g; b1 = on_circle b1 o d, on_line b1 d f ? para c1 c b1 b +examples/complete2/012/complete_004_6_GDD_FULL_81-109_94.gex +a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; p = on_circle p o a; f = foot f p a d; g = foot g p a b; h = foot h p b c; e = foot e p c d; i = on_line i f g, on_line i h e ? cyclic p g i h +examples/complete2/012/complete_003_6_GDD_FULL_21-40_37.gex +a b c = triangle a b c; h = orthocenter h a b c; o = circle o a b c; c1 = on_circle c1 o c, on_line c1 c h; a1 = on_circle a1 o a, on_line a1 a h ? cong b a1 b c1 +examples/complete2/012/complete_003_6_GDD_FULL_21-40_22.gex +a b c = triangle a b c; o = circle o a b c; p = foot p o a c; q = foot q o a b; m = on_line m o q, on_circle m o a; n = on_line n o p, on_circle n o a; e = on_line e a c, on_line e n m; d = on_line d a b, on_line d n m ? eqangle d a d e e d e a +examples/complete2/012/complete_001_6_GDD_FULL_01-20_19.gex +a b c = triangle a b c; f = free f; p = circle p a b f; o = circle o a b c; e = on_circle e p a, on_line e a c; d = on_circle d o b, on_line d b f ? para c d e f +examples/complete2/012/complete_001_6_GDD_FULL_61-80_74.gex +a b c = triangle a b c; g = foot g c a b; o = circle o a b c; d = on_circle d o c, on_line d c g; e = foot e d a c; f = foot f d b c ? cyclic a e f b +examples/complete2/013/complete_002_6_GDD_FULL_41-60_49.gex +a b c = triangle a b c; p = midpoint p b a; q = midpoint q c b; d = on_tline d b a c; r = midpoint r d c; s = midpoint s a d; o = on_line o p r, on_line o q s ? cong o s o r +examples/complete2/013/complete_006_Other_ndgTest_70.gex +p a b = triangle p a b; o = midpoint o b a; a1 = on_line a1 p a, on_circle a1 o a; b1 = on_line b1 p b, on_circle b1 o a; o1 = circle o1 p a1 b1 ? 
perp o a1 a1 o1 +examples/complete2/013/complete_001_6_GDD_FULL_01-20_16.gex +a b o = triangle a b o; m = on_line m a b; p = foot p m a o; q = foot q m b o; d = foot d b a o; c = foot c a b o; t = foot t q a o; k = foot k p b o; s = on_line s q t, on_line s p k ? perp o s p q +examples/complete2/013/complete_001_6_GDD_FULL_61-80_67.gex +m b c = triangle m b c; i = incenter i m b c; i_b = on_tline i_b c c i, on_line i_b b i; i_c = on_tline i_c b b i, on_line i_c c i; a = midpoint a i_b i_c; o = circumcenter o b i c ? perp a b b o +examples/complete2/013/complete_000_2_PWW_A018.gex +o a = segment o a; p = on_circle p a o; q = intersection_cc q a o p; r = lc_tangent r p a, on_circle r o p ? cong p q p r +examples/complete2/013/complete_004_6_GDD_FULL_81-109_88.gex +o x l = triangle o x l; a = foot a l o x; y = free y; b = foot b l o y; p = mirror p l a; q = mirror q l b; q1 = on_line q1 l a, on_line q1 o y; p1 = on_line p1 o x, on_line p1 l b ? cyclic o p q p1 +examples/complete2/013/complete_003_6_GDD_FULL_21-40_24.gex +q r p = triangle q r p; o1 = circle o1 q r p; s = on_circle s o1 q; y = on_line y q s; o = circle o y p q; x = on_circle x o q; i = on_line i r s, on_line i y x ? eqangle i r i x p r p x +examples/complete2/013/complete_003_6_GDD_FULL_21-40_32.gex +b c r = triangle b c r; o = circle o b c r; s = on_circle s o b; a = on_line a b r, on_line a c s; m = foot m a r s; n = foot n a b c ? eqangle a b a m a n a c +examples/complete2/013/complete_002_6_GDD_FULL_41-60_54.gex +a b c = r_triangle a b c; d = foot d a b c; o = midpoint o c b; m = foot m b a o; g = on_line g b m, on_circle g o a; f = on_line f c a, on_line f b m; e = on_line e a d, on_line e b m ? cong e a e b +examples/complete2/013/complete_005_Other_ndg1_53.gex +a o = segment a o; b = on_circle b o a; c = on_line c a b; e = intersection_tt e b b o c c o; d = intersection_lt d c e a a o ? 
cong o e o d +examples/complete2/013/complete_002_6_GDD_FULL_41-60_56.gex +m a b = iso_triangle m a b; o = circle o a b m; d = on_line d m o, on_line d a b; e = on_tline e a a o, on_pline e m a o ? cong m e m d +examples/complete2/013/complete_002_6_GDD_FULL_41-60_52.gex +e c d = r_triangle e c d; o = midpoint o d c; a = on_tline a c c d, on_tline a e e o; f = on_line f c a, on_line f d e ? cong a e a f +examples/complete2/014/complete_008_7_Book_LLL_L053-1.gex +b c d = triangle b c d; e = foot e b c d; a = free a; f = foot f a c d; g = midpoint g b a ? cong f g g e +examples/complete2/014/complete_007_7_Book_LLL_L058-9.gex +a b = segment a b; c = on_bline c a b; d = on_tline d b a b; e = intersection_lt e c d a a b ? cong e c c d +examples/complete2/007/complete_003_6_GDD_FULL_more_E015-6.gex +a b c = triangle a b c; d = midpoint d a c; e = midpoint e b a; f = midpoint f c b; g = on_pline g d a f, on_pline g f a c ? para c e g b +examples/complete2/007/complete_003_6_GDD_FULL_more_E022-9.gex +a b c = triangle a b c; d = circumcenter d b a c; e = on_line e a c, angle_bisector e c b a; f = intersection_lc f e d b; g = on_tline g f d f ? para f g a c +examples/complete2/007/complete_012_7_Book_00EE_02_E028-2.gex +a b c = triangle a b c; d e = square a c d e; g f = square c b g f ? perp d b f a +examples/complete2/007/complete_012_7_Book_00EE_05_E051-22.gex +a b = segment a b; c = midpoint c b a; d = on_tline d b a b; e = on_line e a d, on_circle e c a; f = on_pline f e a b, on_circle f c e; g = foot g d a f ? cong a f f g +examples/complete2/007/complete_005_Other_other_E075-25-sss.gex +a b c = triangle a b c; d = midpoint d c b; e = midpoint e a c; f = midpoint f b a; g = foot g a b c ? eqangle d f d e g f g e +examples/complete2/007/complete_001_6_GDD_FULL_01-20_01.gex +a b c = triangle a b c; d = foot d c a b; e = foot e b a c; f = midpoint f c b; g = midpoint g d e ? 
perp f g d e
+examples/complete2/007/complete_000_2_PWW_B016x.gex
+a b c = triangle a b c; d = midpoint d b c; e = midpoint e c a; f = midpoint f b a; g = parallelogram d a e g ? cong c f g b
+examples/complete2/007/complete_001_6_GDD_FULL_61-80_66.gex
+a b c = triangle a b c; d = foot d c a b; e = on_tline e a b c, on_line e c d; f = midpoint f a e; g = midpoint g c b ? perp d g d f
+examples/complete2/007/complete_016_7_Book_00EE_06_E051-30.gex
+a b = segment a b; c = midpoint c b a; d = on_circle d c a; e = lc_tangent e d c, on_line e a b; f = angle_bisector f d e a, on_line f a d; g = on_line g e f, on_line g b d ? cong d f d g
+examples/complete2/007/complete_016_7_Book_00EE_06_E051-24.gex
+a b = segment a b; c = midpoint c b a; d = on_circle d c a; e = on_line e b d; f = circle f d c e; g = on_pline g e a b, on_circle g f c ? cong g e c b
+examples/complete2/007/complete_013_7_Book_00EE_11_E077-37.gex
+a b = segment a b; c = midpoint c b a; d = on_circle d c b; e = foot e d a b; f = lc_tangent f d c; g = on_line g d f ? eqangle d f d a d a d e
+examples/complete2/007/complete_013_7_Book_00EE_07_E059-54-1.gex
+a b = segment a b; c = midpoint c b a; d = mirror d c b; e = on_circle e c a, on_circle e b c; f = on_tline f b a b, on_line f a e; g = on_line g b f, on_line g d e ? cong e g g f
+examples/complete2/007/complete_008_ex-gao_ex160_e201f.gex
+a b c = triangle a b c; d e = square a b d e; f = foot f a b c; g = on_line g a f, eqdistance g a b c ? perp g b c d
+examples/complete2/000/complete_016_ex-gao_gao_M_M020-52.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b, on_pline e d b c ? cong d c e b
+examples/complete2/000/complete_010_Other_gao_L_L190-7.gex
+a b = segment a b; c = nsquare c b a; d = psquare d a b; e = on_line e b d; f = foot f e b c; g = foot g e d c ? cong e a g f
+examples/complete2/000/complete_007_7_Book_LLL_L017-11.gex
+c a = segment c a; b = on_tline b c c a; d = foot d c a b ? eqangle a c a d c b c d
+examples/complete2/000/complete_016_ex-gao_gao_M_M024-94.gex
+a b c = triangle a b c; d = on_circle d a b, on_circle d c b; e = on_line e d a, on_circle e a d; f = on_line f d c, on_circle f c d ? coll e b f
+examples/complete2/000/complete_007_7_Book_LLL_L054-2-1.gex
+a b = segment a b; c = on_bline c a b; d = on_line d a c; e = eqdistance e b a d, on_line e b c; f = on_line f a b, on_line f d e; g = on_line g b c, on_pline g d a b ? cong d f f e
+examples/complete2/000/complete_007_7_Book_LLL_L057-1-1.gex
+a b = segment a b; m = midpoint m a b; c = on_circle c m a; d = angle_mirror d a b c; e = midpoint e a d; f = on_line f b d, on_line f a c ? para c e b d
+examples/complete2/000/complete_016_ex-gao_gao_M_M021-64.gex
+a b = segment a b; d = midpoint d b a; c = on_circle c b a, on_circle c a b; e = on_line e a c, on_circle e d a; f = on_circle f d a, on_line f b c ? cong a e e f
+examples/complete2/000/complete_007_7_Book_LLL_L057-3-2.gex
+a b = segment a b; c = on_bline c a b; e = midpoint e a c; d = on_circle d a c, on_line d a c; f = midpoint f b d ? cong b e b f
+examples/complete2/000/complete_004_6_GDD_FULL_81-109_95.gex
+a b c = triangle a b c; a1 = midpoint a1 c b; f = circle f a b c; s = on_aline s a b c a a1, on_line s b c; p = on_circle p f a, on_line p a a1; q = on_circle q f a, on_line q a s ? para b c p q
+examples/complete2/000/complete_001_6_GDD_FULL_01-20_02.gex
+a b c = triangle a b c; a1 = midpoint a1 c b; b1 = midpoint b1 c a; c1 = midpoint c1 b a; o = circle o a b c ? perp o a1 b1 c1
+examples/complete2/000/complete_004_6_GDD_FULL_81-109_96.gex
+a b c = triangle a b c; a1 = midpoint a1 c b; n = on_line n a a1; g = foot g n a b; h = foot h n a c; s = on_aline s a b c a a1, on_line s b c ? perp g h a s
+examples/complete2/000/complete_007_7_Book_LLL_L194-2.gex
+a b = segment a b; c = on_bline c a b; d = on_bline d a b; e = on_line e c d, on_line e a b ? cong a e e b
+examples/complete2/000/complete_017_ex-gao_gao_L_L022-1.gex
+a b = segment a b; c = on_bline c a b; d = on_line d a c; e = on_circle e c d, on_line e b c ? cong a e b d
+examples/complete2/000/complete_016_ex-gao_gao_M_M09-14.gex
+c b = segment c b; d = midpoint d c b; a = free a; e = midpoint e b a; f = midpoint f c a; g = on_line g a d, on_line g e f ? cong e g g f
+examples/complete2/009/complete_014_7_Book_00EE_09_E071-4.gex
+a b = segment a b; c = midpoint c b a; d = on_circle d c a; e = lc_tangent e d c, on_line e a b; f = foot f a d e ? eqangle a f a d a d a b
+examples/complete2/009/complete_013_7_Book_00EE_10_E072-13.gex
+a b c = triangle a b c; d = foot d b a c; e = foot e a b c; f = foot f b d e ? eqangle b a b d b c b f
+examples/complete2/009/complete_014_7_Book_00EE_09_E071-2.gex
+a b c = triangle a b c; d = midpoint d c b; e = foot e b a c; f = foot f c a b ? eqangle a b a c e f e d
+examples/complete2/009/complete_014_7_Book_00EE_09_E071-1.gex
+a b c = triangle a b c; e = on_line e a b, on_circle e a c; d = angle_bisector d b a c, on_line d b c; f = on_pline f e b c, on_line f a c ? eqangle e d e c e c e f
+examples/complete2/009/complete_017_ex-gao_ex160_4_e10.gex
+a b c d = isquare a b c d; e = on_line e b d, on_circle e b c; f = on_tline f e b d, on_line f d c ? cong e d c f
+examples/complete2/009/complete_003_6_GDD_FULL_more_E022-12.gex
+a b c = triangle a b c; e = circumcenter e a b c; d = on_line d a b, angle_bisector d a c b; f = on_tline f c c e, on_pline f d a c ? cong c f d b
+examples/complete2/009/complete_001_6_GDD_FULL_61-80_69.gex
+d a b = r_triangle d a b; c = midpoint c b a; e = circle e a c d; f = circle f b d c ? perp e d d f
+examples/complete2/009/complete_012_7_Book_00EE_05_E051-19.gex
+a b c = triangle a b c; d = circumcenter d a c b; e = on_line e b c; f = on_circle f d a, angle_bisector f a c e ? cong a f f b
+examples/complete2/009/complete_016_7_Book_00EE_06_E051-32.gex
+a b c = triangle a b c; d = eq_triangle d a b; e = eq_triangle e a c; f = eq_triangle f c b ? para e d c f
+examples/complete2/009/complete_013_7_Book_00EE_10_E074-23.gex
+a b c = triangle a b c; d = foot d a b c; e = circumcenter e b a c; f = angle_bisector f b a c, on_circle f e a ? eqangle e a a f f a a d
+examples/complete2/009/complete_011_Other_Auxiliary_aux2_trapezoid.gex
+a b c d = trapezoid a b c d; e = midpoint e d a; f = on_pline f e a b, on_line f b c ? midp f b c
+examples/complete2/009/complete_016_7_Book_00EE_06_E057-37.gex
+a b c = triangle a b c; d = eq_triangle d a b; e = eq_triangle e a c; f = parallelogram c e d f ? cong b f f c
+examples/complete2/008/complete_004_6_GDD_FULL_81-109_100.gex
+a c = segment a c; b = eq_triangle b c a; e = mirror e c b; d = mirror d b e; f = foot f d a b ? perp a c c f
+examples/complete2/008/complete_005_Other_ndgs_02.gex
+b a c = triangle b a c; d = foot d b a c; e = foot e c a b; f = intersection_ll f b d c e ? perp b c a f
+examples/complete2/008/complete_008_ex-gao_ex160_205.gex
+c a b = r_triangle c a b; d = midpoint d c a; f = midpoint f c b; e = on_line e a b, on_circle e d c ? perp d e e f
+examples/complete2/008/complete_015_7_Book_00EE_08_E061-62.gex
+a b = segment a b; c = on_circle c a b; e = on_circle e a b; d = on_circle d a b, on_circle d b c; f = on_circle f b c, on_line f c e ? cong e d e f
+examples/complete2/008/complete_015_7_Book_00EE_06_E051-31.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = eq_triangle e a b; f = eq_triangle f b c ? cong d e d f
+examples/complete2/008/complete_011_7_Book_00EE_03_E037-22.gex
+c a b = risos c a b; e = midpoint e b a; d = on_line d a b, on_circle d b c; f = on_line f a c, on_circle f c e ? perp a c f d
+examples/complete2/008/complete_011_7_Book_00EE_03_E037-21.gex
+a b = segment a b; c = on_circle c a b; d = lc_tangent d c a, on_line d a b; e = on_line e a b, on_circle e a b; f = on_pline f a c e, on_line f c d ? perp f b a b
+examples/complete2/008/complete_011_7_Book_00EE_04_E051-5.gex
+c a = segment c a; b = eq_triangle b c a; d = circumcenter d c a b; e = on_pline e d a c, on_line e a b; f = on_pline f d b c, on_line f a b ? cong a e e f
+examples/complete2/008/complete_003_6_GDD_FULL_more_E009-1.gex
+a c = segment a c; b = on_tline b c a c; d = on_dia d b a, on_circle d a c; e = on_line e b c, on_circle e a b; f = on_line f b d, on_circle f a b ? para c d e f
+examples/complete2/008/complete_011_7_Book_00EE_03_E039-28.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b, on_circle d c b; e = mirror e d c; f = on_circle f a b, on_line f b e ? coll d a f
+examples/complete2/008/complete_011_7_Book_00EE_03_E040-28-1.gex
+c a b = iso_triangle c a b; d = on_line d b c; e = circle e a b d; f = on_circle f e a, on_line f a c ? para a b f d
+examples/complete2/008/complete_018_ex-gao_ex160_4_004.gex
+b a c = triangle b a c; d = on_line d b c, on_circle d a b; e = on_tline e c a c, on_tline e b a b; f = on_tline f d a d, on_line f c e ? cong e c c f
+examples/complete2/008/complete_014_7_Book_00EE_07_E059-50.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = circle e c a d; f = on_line f b c, on_circle f e a ? cong d f f b
+examples/complete2/008/complete_013_7_Book_00EE_07_E057-44.gex
+c a b = iso_triangle c a b; d = foot d a b c; e = foot e b a c; f = on_line f a d, on_line f b e ? cong f a f b
+examples/complete2/001/complete_006_7_Book_LLL_L046-16.gex
+a b = segment a b; c = on_line c a b; d = on_circle d c a, on_circle d a c; e = on_aline e b a d c a, on_aline e c a d a b; f = on_line f c d, on_line f a e; g = on_line g b d, on_line g c e ? cong c f c g
+examples/complete2/001/complete_016_ex-gao_gao_M_M010-32.gex
+b c a = triangle b c a; d = on_pline d a b c, on_pline d c a b; e = on_line e b c; f = on_line f a d, on_pline f e a b; g = on_line g a e, on_line g b f; h = on_line h c f, on_line h d e ? para h g d a
+examples/complete2/001/complete_016_ex-gao_gao_M_M010-26.gex
+b d = segment b d; e = midpoint e b d; c = free c; a = on_pline a d b c, on_pline a b d c; f = on_line f c d; g = on_line g a b, on_line g e f; h = on_line h e f, on_line h a d; i = on_line i b c, on_line i e f ? cong f h g i
+examples/complete2/001/complete_016_ex-gao_gao_C_C101.gex
+a b c = triangle a b c; e = foot e a b c; f = foot f c a b; d = on_bline d a c, on_bline d a b; g = on_line g c f, on_line g a e; h = on_line h c f, on_circle h d c ? cong g f f h
+examples/complete2/001/complete_016_ex-gao_gao_C_C100.gex
+a c = segment a c; b = on_tline b c c a; e = on_circle e b c; d = on_circle d a c, on_circle d b c; f = on_line f c e, on_circle f a c; g = on_line g e b, on_circle g b e ? coll d f g
+examples/complete2/001/complete_016_ex-gao_gao_L_L182-6.gex
+a b c = triangle a b c; e = midpoint e b c; d = on_line d a b; f = midpoint f d c; g = midpoint g b a; h = midpoint h g f; i = on_line i a b, on_line i e h ? cong a i i d
+examples/complete2/001/complete_016_ex-gao_gao_C_C111.gex
+a d c = triangle a d c; b = on_pline b a d c; e = on_line e a d; f = on_line f a c, on_pline f e a b; g = on_line g b d, on_line g e f; h = on_line h b c, on_line h e f ? cong e f g h
+examples/complete2/001/complete_016_ex-gao_gao_L_L025-5.gex
+a b = segment a b; c = on_bline c a b; d = on_line d a c; e = on_circle e c d, on_line e b c; f = on_line f b d, on_line f a e ? eqangle a c c f f c c b
+examples/complete2/001/complete_017_ex-gao_gao_L_L189-2.gex
+a b = segment a b; c = on_bline c a b; e = midpoint e c a; f = midpoint f b c; d = on_pline d b a c, on_pline d a b c; g = midpoint g d b; h = midpoint h a d ? perp h e e f
+examples/complete2/001/complete_016_ex-gao_gao_L_L182-5.gex
+c d a = triangle c d a; b = on_pline b c d a, on_pline b a d c; e = on_line e c d; f = on_line f a b, on_pline f c a e; g = on_line g b e, on_line g c f; h = on_line h d f, on_line h a e ? cong g e f h
+examples/complete2/001/complete_017_ex-gao_gao_L_L189-1.gex
+a b = segment a b; c = on_tline c b a b; d = on_tline d c b c, on_tline d a a b; e = midpoint e c d; f = midpoint f b c; g = midpoint g a b; h = midpoint h a d ? cong h g h e
+examples/complete2/001/complete_016_ex-gao_gao_C_C109.gex
+b d a = triangle b d a; c = on_pline c d a b; e = on_line e b d, on_line e a c; f = on_line f a d, on_pline f e a b; g = on_line g b c, on_line g e f ? cong f e e g
+examples/complete2/001/complete_016_ex-gao_gao_L_LL153-1.gex
+c a d = triangle c a d; e = foot e c a d; b = free b; f = foot f b a d; g = midpoint g c b; h = midpoint h e f ? cong g e g f
+examples/complete2/001/complete_010_Other_gao_Y_yL182-4.gex
+c d = segment c d; e = midpoint e c d; a = free a; b = on_pline b c d a, on_pline b a d c; f = midpoint f a b; g = on_line g a c, on_line g b e; h = on_line h d f, on_line h a c ? cong a h h g
+examples/complete2/006/complete_012_7_Book_00EE_02_E028-3.gex
+c a b = risos c a b; d = midpoint d b a; e = on_line e b c; f = circle f d b e; g = on_line g a e, on_circle g f b ? perp c g a e
+examples/complete2/006/complete_003_6_GDD_FULL_more_E022-11.gex
+a b c = triangle a b c; d = circumcenter d a b c; f = foot f d a b; e = on_tline e c c d, on_tline e b b d; g = on_line g d f, on_line g a c ? para g e a b
+examples/complete2/006/complete_010_Other_Auxiliary_aux2_e04f.gex
+a b c d = trapezoid a b c d; e = midpoint e c a; f = midpoint f d b; g = on_line g e f, on_line g a d ? midp g a d
+examples/complete2/006/complete_004_6_GDD_FULL_81-109_98.gex
+a b c = triangle a b c; e = on_line e a b; d = circle d a b c; f = on_circle f d a, on_aline f c b a c e; g = on_circle g d c, on_line g c e ? para a b g f
+examples/complete2/006/complete_001_6_GDD_FULL_61-80_72.gex
+a b c = triangle a b c; d = circle d a b c; e = on_circle e d a; f = foot f e a c; g = foot g e a b ? simtri e f g e c b
+examples/complete2/006/complete_013_7_Book_00EE_11_E075-26.gex
+a b = segment a b; c = mirror c a b; d = mirror d b c; e = midpoint e c b; f = on_circle f e c, on_dia f a e; g = on_line g a f ? eqangle b f a f c f e f
+examples/complete2/006/complete_015_7_Book_00EE_06_E057-38.gex
+c a b = r_triangle c a b; d = foot d c a b; e = angle_bisector e c a b, on_line e b c; g = foot g e a b; f = on_line f c d, on_line f a e ? cong c e c f
+examples/complete2/006/complete_014_7_Book_00EE_07_E059-47.gex
+a b c d = rectangle a b c d; e = on_line e b d, on_line e a c; f = midpoint f e d; g = midpoint g e a ? cong f c g b
+examples/complete2/006/complete_014_7_Book_00EE_07_E059-53.gex
+a b c = triangle a b c; d = circle d c a b; e = circle e c d b; f = on_line f a b, on_circle f e b; g = on_line g a c, on_circle g e b ? cong g b g a
+examples/complete2/006/complete_003_6_GDD_FULL_more_E023-15.gex
+a b c d = quadrangle a b c d; e = on_line e a c; g = on_pline g e a b, on_line g b c; f = on_pline f e a d, on_line f c d ? para b d g f
+examples/complete2/011/complete_002_6_GDD_FULL_01-20_12.gex
+a b c = triangle a b c; o = circle o a b c; d = on_tline d b a c, on_circle d o a; f = midpoint f b a; e = on_line e a c, on_line e b d ? perp f e c d
+examples/complete2/011/complete_002_6_GDD_FULL_01-20_05.gex
+a b c = triangle a b c; h = orthocenter h a b c; o = circumcenter o a b c; c1 = circumcenter c1 a b h; b1 = circumcenter b1 a h c; a1 = circumcenter a1 b h c ? perp a1 o b1 c1
+examples/complete2/011/complete_003_6_GDD_FULL_21-40_34.gex
+a b c = triangle a b c; h = orthocenter h a b c; o = circle o h b c; p = on_tline p h c h, on_circle p o b ? para a h b p
+examples/complete2/011/complete_004_6_GDD_FULL_81-109_99.gex
+a b c = triangle a b c; m = free m; n = on_aline n a c b a m; q = foot q m a b; p = foot p m a c ? perp a n p q
+examples/complete2/011/complete_003_6_GDD_FULL_21-40_35.gex
+a b c = triangle a b c; d = foot d c a b; e = foot e b a c; o = circle o a b c; k = on_circle k o c, on_line k c d; h = on_line h c d, on_line h b e ? cong a k a h
+examples/complete2/011/complete_003_6_GDD_FULL_21-40_31.gex
+a b c = triangle a b c; c1 = midpoint c1 b a; b1 = midpoint b1 c a; o = circle o a b c; p = on_line p o c1, on_line p a c; q = on_line q o b1, on_line q a b ? cyclic q b c p
+examples/complete2/011/complete_002_6_GDD_FULL_41-60_41.gex
+a b c = triangle a b c; i = incenter i a b c; y = foot y i a c; l = foot l i b c; x = foot x b a i ? coll x y l
+examples/complete2/011/complete_002_6_GDD_FULL_41-60_43.gex
+c a b = r_triangle c a b; f e = square a b f e; p = on_line p b e, on_line p a f ? eqangle c a c p c p c b
+examples/complete2/011/complete_002_6_GDD_FULL_41-60_51.gex
+a b c = triangle a b c; o = circle o a b c; d = on_tline d b a c, on_circle d o a; e = on_circle e o d, on_line e d o ? para b e a c
+examples/complete2/011/complete_002_6_GDD_FULL_41-60_44.gex
+a b c = triangle a b c; o = circle o a b c; d = on_circle d o a; a1 = on_tline a1 a a b, on_line a1 c d; c1 = on_tline c1 c c d, on_line c1 a b ? para d b a1 c1
+examples/complete2/010/complete_004_6_GDD_FULL_21-40_29.gex
+a b c = triangle a b c; d = foot d a b c; q = foot q d a b; p = foot p d a c ? cyclic b q p c
+examples/complete2/010/complete_002_6_GDD_FULL_01-20_10.gex
+a b c d = quadrangle a b c d; e = on_line e b c, on_line e a d; o1 = circle o1 c d e; o = circle o e b a; p = on_line p c d, on_line p a b; q = on_circle q o1 c, on_circle q o a ? cyclic p d q a
+examples/complete2/010/complete_013_7_Book_00EE_10_E072-15.gex
+b c a = triangle b c a; d = lc_tangent d c a, lc_tangent d b a; e = on_circle e a c, on_dia e a d ? eqangle b e b a b a b c
+examples/complete2/010/complete_011_7_Book_00EE_04_E051-6.gex
+a b c = triangle a b c; d = eq_triangle d c a; e = eq_triangle e b a ? cong d e c b
+examples/complete2/010/complete_012_7_Book_00EE_05_E051-20.gex
+a b = segment a b; d = midpoint d a b; c = on_tline c b a b; e = on_line e a c, on_circle e d b; f = lc_tangent f e d, on_line f b c ? cong f c f b
+examples/complete2/010/complete_011_7_Book_00EE_03_E037-20.gex
+a b = segment a b; c = midpoint c a b; d = on_circle d c a; e = lc_tangent e d c, angle_mirror e b a d ? perp a e e d
+examples/complete2/010/complete_012_7_Book_00EE_11_E076-32.gex
+c a b = r_triangle c a b; d = midpoint d b c; e = foot e c a d ? eqangle a b b c d e e b
+examples/complete2/010/complete_000_3_JAR_JAR02-new_fig214.gex
+a b c = triangle a b c; d = intersection_pp d a b c c a b; e = intersection_ll e a c b d ? cong a e e c
+examples/complete2/010/complete_003_6_GDD_FULL_more_E021-3.gex
+a b = segment a b; c = on_circle c a b; e = intersection_lc e a a c; d = on_tline d c a c, on_tline d b a b ? para a d b e
+examples/complete2/010/complete_013_7_Book_00EE_10_E074-22.gex
+a b = segment a b; c = on_circle c a b; d = on_line d a c; e = on_line e a b, on_circle e a d; f = on_line f b d, on_line f c e ? eqangle a b a f a f a c
+examples/complete2/010/complete_001_6_GDD_FULL_01-20_20.gex
+a b c = triangle a b c; d = foot d a b c; e = foot e b a c; h = on_line h a d, on_line h b e; g = foot g h a b ? eqangle g e g h g h g d
+examples/complete2/010/complete_002_6_GDD_FULL_41-60_57.gex
+a b c = triangle a b c; d = foot d a b c; o = midpoint o a d; e = on_line e a b, on_circle e o d; f = on_line f a c, on_circle f o d ? cyclic b c e f
+examples/complete2/010/complete_010_Other_Auxiliary_ye_aux_ppara.gex
+a b c d = eq_trapezoid a b c d; e = on_pline e b a d, on_line e c d ? eqangle a d a b b a b c
+examples/complete2/003/complete_003_6_GDD_FULL_more_E013-3.gex
+a b = segment a b; d = on_tline d b a b; e = on_circle e a b; f = on_line f d e, on_circle f a b; g = midpoint g e f; c = on_circle c a b, on_dia c d a; h = intersection_lc h g a c ? para e f b h
+examples/complete2/003/complete_005_Other_ndgs_01.gex
+b c a = triangle b c a; d = intersection_cc d b a c; e = on_circle e b c; g = intersection_lc g e a d; f = on_circle f b c; h = intersection_lc h f a c ? para g h e f
+examples/complete2/003/complete_013_7_Book_00EE_10_E072-12.gex
+b a c = triangle b a c; d = on_circle d a b, on_circle d c b; e = on_tline e d b d, on_circle e a b; f = on_circle f c d, on_line f d e; h = on_circle h a b, on_line h b f; g = on_circle g c b, on_line g b e ? eqangle d h d b d b d g
+examples/complete2/003/complete_010_Other_Auxiliary_ye_aux_wang3.gex
+a b c d = isquare a b c d; f = angle_bisector f a d b, on_line f a c; g = foot g c d f; e = on_line e b d, on_line e a c; h = on_line h c g, on_line h a d; i = on_line i b d, on_line i c g; x = midpoint x a h ? cong a x e i
+examples/complete2/003/complete_003_6_GDD_FULL_more_E022-8.gex
+b a c = triangle b a c; d = on_circle d a b, on_circle d c b; e = on_circle e c b; f = intersection_lc f e a d; g = intersection_lc g e a b; h = on_tline h e c e ? para h e g f
+examples/complete2/003/complete_008_ex-gao_ex160_206.gex
+a b c = triangle a b c; e d = square b a e d; f g = square a c f g; h = on_line h b e, on_line h a d; i = on_pline i e a g, on_pline i g a e ? perp c h h i
+examples/complete2/003/complete_013_7_Book_00EE_11_E077-38.gex
+a b c = triangle a b c; d = circumcenter d a b c; e = lc_tangent e b d; f = angle_bisector f e b c, on_circle f d a; g = foot g f b c; h = foot h f b e ? eqangle b a a f f a a c
+examples/complete2/003/complete_004_6_GDD_FULL_81-109_84.gex
+a b c = triangle a b c; d = midpoint d b a; e = midpoint e c b; f = midpoint f a c; g = circle g d e f; h = on_tline h e e g, on_line h a b; i = on_line i e h, on_line i a c ? cyclic b h c i
+examples/complete2/003/complete_003_6_GDD_FULL_more_E022-10.gex
+a b = segment a b; c = on_circle c a b; e = on_circle e a b; d = on_circle d a b; f = on_line f c e, on_line f b d; g = circumcenter g e f b; h = on_tline h f f g ? para h f c d
+examples/complete2/003/complete_011_7_Book_00EE_03_E037-25.gex
+a b = segment a b; c = midpoint c b a; d = on_tline d c a b, on_circle d c a; e = on_circle e c d, on_line e d c; f = on_line f a b; g = on_line g c d, on_circle g c f; h = on_circle h c e, on_line h e f; i = on_line i b g, on_circle i c a ? perp e h b i
+examples/complete2/003/complete_016_7_Book_00EE_06_E051-25.gex
+a b c = triangle a b c; d = foot d b a c; e = foot e c a b; g = circumcenter g b c a; h = intersection_lc h g g a; f = on_line f b d, on_line f c e; i = on_line i b c, on_line i f h ? cong f i i h
+examples/complete2/003/complete_013_7_Book_00EE_10_E074-20.gex
+a b c = triangle a b c; d = on_line d a b; e = foot e d b c; f = foot f e a b; g = on_line g b c; h = foot h g a b; i = foot i h b c ? eqangle d g d h f i f d
+examples/complete2/003/complete_017_ex-gao_ex160_4_003.gex
+a b = segment a b; c = on_circle c a b; e = on_line e b c; d = on_tline d c a c, on_tline d b a b; f = on_tline f e a e, on_line f c d; g = on_line g e f, on_circle g a b; h = on_line h e f, on_line h b d ? cong c f h b
+examples/complete2/003/complete_015_7_Book_00EE_08_E059-56.gex
+a b c = triangle a b c; d = foot d b a c; e = foot e a b c; f = on_line f b d, on_line f a e; g = circle g b f a; h = on_line h b c, on_circle h g a; i = on_line i a c, on_circle i g a ? cong f c f i
+examples/complete2/003/complete_014_7_Book_00EE_07_E059-52.gex
+a b = segment a b; d = on_circle d a b; c = on_tline c a a b, on_circle c a b; e = on_tline e c a c, on_tline e b a b; f = lc_tangent f d a, on_line f c e; g = on_line g b e, on_line g d f; h = on_line h c e, on_line h b d ? cong h e d g
+examples/complete2/004/complete_002_6_GDD_FULL_01-20_13.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = foot e b a c; f = foot f a b d; g = foot g d a c; h = foot h c b d ? para e f g h
+examples/complete2/004/complete_006_Other_Auxiliary_E092-5.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = angle_bisector e d a b, on_dia e a d; f = foot f b a e; g = foot g c b f; h = on_line h d e, on_line h c g ? para e g a b
+examples/complete2/004/complete_004_6_GDD_FULL_81-109_86.gex
+a b c = triangle a b c; d = midpoint d c a; e = midpoint e b c; f = midpoint f a b; g = angle_bisector g a b c, on_line g d f; h = on_line h b g, on_line h d e ? cong d g d h
+examples/complete2/004/complete_011_7_Book_00EE_03_E037-26.gex
+a b c d = isquare a b c d; e = on_line e b c; g = on_line g d c, on_line g a e; f = on_line f b d, on_line f a e; h = circle h g e c ? perp f c c h
+examples/complete2/004/complete_016_7_Book_00EE_06_E051-27.gex
+a b = segment a b; c = midpoint c b a; d = on_circle d c a; e = angle_bisector e d c a, on_circle e c a; f = foot f e a b; g = intersection_ll g a d e f; h = intersection_ll h a d b e ? cong a g g e
+examples/complete2/004/complete_001_6_GDD_FULL_61-80_73.gex
+a b c = triangle a b c; d = circle d a b c; e = on_circle e d a; f = foot f e a c; g = foot g e a b; h = on_circle h d e, on_line h e g ? para g f h c
+examples/complete2/004/complete_014_7_Book_00EE_07_E057-42.gex
+a b c = triangle a b c; d = midpoint d a c; e = midpoint e b a; f = midpoint f c b; g = on_line g a b; h = on_pline h d f g, on_line h a b ? cong h a g e
+examples/complete2/005/complete_005_Other_ndgs_03.gex
+b a c = triangle b a c; d = foot d b a c; e = foot e c a b; f = midpoint f c b; g = foot g f d e ? cong g e g d
+examples/complete2/005/complete_000_rebuilt example_9point.gex
+a b c = triangle a b c; d = foot d a b c; e = midpoint e b a; f = midpoint f c b; g = midpoint g a c; o = circumcenter o e f g ? cyclic d g e f
+examples/complete2/005/complete_013_7_Book_00EE_11_E081-2.gex
+a b = segment a b; d = on_circle d a b; c = on_circle c a b, on_circle c b d; e = foot e d b c; f = intersection_lc f e a d; g = on_line g d e, on_circle g e f ? cong f c g d
+examples/complete2/005/complete_002_6_GDD_FULL_41-60_58.gex
+a b c = triangle a b c; d = circle d a b c; e = on_circle e d a; f = on_line f a b, on_line f c e; g = on_pline g f a e, on_line g b c ? eqangle f g f b c f c b
+examples/complete2/005/complete_016_7_Book_00EE_06_E051-26.gex
+a b = segment a b; c = on_circle c a b; e = on_line e b c; d = lc_tangent d b a, lc_tangent d c a; f = on_tline f e a e, on_line f c d; g = on_line g e f, on_line g b d ? cong g e e f
+examples/complete2/005/complete_001_6_GDD_FULL_61-80_61.gex
+a b c = triangle a b c; e = midpoint e b a; d = circle d a b c; f = on_line f d e; g = on_line g b c, on_circle g f a ? simtri a d f a c g
+examples/complete2/005/complete_017_ex-gao_ex160_4_e03a_lratio.gex
+c a b = iso_triangle c a b; e = midpoint e b c; f = on_line f a b, on_circle f e b; d x = trisegment d x c b; g = on_line g c f, on_line g a d ? cong c g g f
+examples/complete2/005/complete_008_ex-gao_ex160_e122.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b, on_pline e d b c; f = on_circle f a b; g = on_circle g a b, on_pline g d c f ? para g b e f
+examples/complete2/002/complete_007_7_Book_LLL_yL251-1.gex
+a d b = triangle a d b; c = free c; e = on_line e a d; f = on_line f a b; g = on_line g a c; h = on_line h a d; i = on_line i a c, on_pline i h e g; j = on_line j a b, on_pline j i f g ? para e f h j
+examples/complete2/002/complete_017_ex-gao_ex160_4_e12.gex
+a b c d = quadrangle a b c d; e = midpoint e c b; f = midpoint f d c; g = midpoint g a d; h = midpoint h b a; i = on_line i e g, on_line i f h ? midp i h f
+examples/complete2/002/complete_013_7_Book_00EE_10_E073-18.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = on_line e b d; f = foot f e a b; g = foot g e b c; h = on_line h e f, on_line h c d; i = on_line i e g, on_line i a d ? para h i g f
+examples/complete2/002/complete_011_7_Book_00EE_03_E043-3.gex
+a b = segment a b; c = on_circle c a b; f = on_circle f a b; e = on_line e b c, on_dia e f a; g = intersection_lc g e a f; d = lc_tangent d b a, lc_tangent d c a; h = on_line h c d, on_line h e f; i = on_line i e f, on_line i b d ? cong c h b i
+examples/complete2/002/complete_012_7_Book_00EE_02_E028-2-1.gex
+a b c = triangle a b c; e d = square a c e d; f g = square c b f g; h = midpoint h a e; i = midpoint i b g; j = midpoint j a b ? perp h j j i
+examples/complete2/002/complete_008_ex-gao_ex160_e124.gex
+b a c = triangle b a c; e = on_circle e a b; d = on_circle d a b, on_circle d c b; h = on_tline h e a e, on_circle h c b; g = on_circle g c d, on_line g d e; f = on_circle f c b, on_line f b e; i = on_circle i c h, on_line i h e ? para g f h i
+examples/complete2/000/complete_004_6_GDD_FULL_81-109_106.gex
+q4 q1 q3 q0 q2 = pentagon q4 q1 q3 q0 q2; p0 = on_line p0 q4 q1, on_line p0 q0 q2; p4 = on_line p4 q4 q1, on_line p4 q3 q0; p3 = on_line p3 q3 q0, on_line p3 q4 q2; p2 = on_line p2 q1 q3, on_line p2 q4 q2; p1 = on_line p1 q1 q3, on_line p1 q0 q2; o0 = circle o0 q0 p0 p4; o1 = circle o1 p1 q1 p0; o4 = circle o4 p4 p3 q4; o3 = circle o3 p3 p2 q3; o2 = circle o2 p1 p2 q2; m0 = on_circle m0 o0 q0, on_circle m0 o1 q1; m4 = on_circle m4 o0 q0, on_circle m4 o4 q4; m3 = on_circle m3 o4 q4, on_circle m3 o3 q3; m2 = on_circle m2 o3 q3, on_circle m2 o2 q2; m1 = on_circle m1 o2 q2, on_circle m1 o1 q1 ? cyclic m4 m3 m2 m1
+examples/complete2/unsolved2/complete_010_Other_Auxiliary_ye_aux_think2.gex
+c a b = iso_triangle c a b; d = on_line d a c; e = on_line e b c, eqdistance e b d a; f = on_line f a b, on_line f d e; g = on_pline g f a c, on_line g b c ? midp f d e
+examples/complete2/unsolved2/complete_012_7_Book_00EE_02_E023-21.gex
+a b = segment a b; c = on_tline c b a b; d = on_circle d a b; f = midpoint f c b; g = on_line g d f, on_circle g a b; h = intersection_lc h c a g; e = on_line e c d, on_circle e a b ? para b c h e
+examples/complete2/unsolved2/complete_006_7_Book_LLL_yL252-6.gex
+a c d = triangle a c d; b = on_pline b c d a, on_pline b a d c; e = on_line e a c; g = on_line g a b, on_pline g e a d; f = on_line f a d, on_pline f e c d; h = on_line h e g, on_line h c d; i = on_line i b c, on_line i e f ? para f g h i
+examples/complete2/unsolved2/complete_015_7_Book_00EE_08_E059-59.gex
+a b c = triangle a b c; d = angle_bisector d b a c; e = on_pline e c b d, on_pline e b c d; f = on_line f b e, on_line f a c; g = on_line g c e, on_line g a b ? cong b g c f
+examples/complete2/unsolved2/complete_013_7_Book_00EE_10_E072-16.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = on_line e c d; f = on_line f a d, eqdistance f c a e; g = on_line g a e, on_line g c f ? eqangle g a g b g b g c
+examples/complete2/unsolved2/complete_003_6_GDD_FULL_more_E023-19.gex
+c a b = r_triangle c a b; d = foot d c a b; e = on_line e c d, angle_bisector e b a c; f = on_line f a b, angle_bisector f d c b; g = on_line g b c, on_line g a e ? para e f c b
+examples/complete2/unsolved2/complete_010_Other_Auxiliary_ye_aux_ll43.gex
+a b = segment a b; c = on_dia c a b, on_bline c a b; d = midpoint d a c; e = foot e c b d; f = on_line f c e, on_line f a b ? eqangle d c d b d f d a
+examples/complete2/unsolved2/complete_010_Other_Auxiliary_aux2_22.gex
+c a b = iso_triangle c a b; d = on_line d a c; e = on_line e b c, eqdistance e b a d; f = on_line f a b, on_line f d e ? cong d f e f
+examples/complete2/unsolved2/complete_014_7_Book_00EE_09_E066-04.gex
+a b = segment a b; c = lc_tangent c b a; d = midpoint d b c; e = on_circle e a b; f = on_line f d e, on_circle f a b ? eqangle e c c d d f f c
+examples/complete2/unsolved2/complete_014_7_Book_00EE_08_E061-66.gex
+a b = segment a b; c = s_angle b a c 60; d = foot d a b c; e = foot e b a c; g = circumcenter g b c a; f = on_line f a d, on_line f b e ? cong a f a g
+examples/complete2/unsolved2/complete_011_7_Book_00EE_03_E037-24.gex
+a b c = triangle a b c; d = circle d b a c; e = lc_tangent e a d, on_line e b c; f = angle_bisector f b e a, on_line f a b; g = on_line g a c, on_line g e f; h = angle_bisector h b a c, on_line h b c ? perp f e a h
+examples/complete2/unsolved2/complete_014_7_Book_00EE_08_E061-65.gex
+a b c d = isquare a b c d; e = s_angle c d e 15, s_angle d c e -15; f = reflect f e a c ? contri e a b a b e
+examples/complete2/unsolved2/complete_012_7_Book_00EE_11_E076-31.gex
+a b c = triangle a b c; d = angle_bisector d c b a, on_line d a c; e = angle_bisector e a c b, on_line e a b; f = on_line f d e, on_line f b c; g = on_line g a b ? eqangle a g a f a f a c
+examples/complete2/unsolved2/complete_010_Other_Auxiliary_ye_aux_y1.gex
+a b c = triangle a b c; d = angle_bisector d a b c, on_dia d b c; e = angle_bisector e b a c, on_dia e a c ? para d e a b
+examples/complete2/unsolved2/complete_004_6_GDD_FULL_21-40_40.gex
+a b c = triangle a b c; i = incenter i a b c; e = on_pline e i a b, on_line e a c ? cong e i e a
+examples/complete2/unsolved2/complete_014_7_Book_00EE_09_E069-8.gex
+a b c = triangle a b c; d = parallelogram a b c d; e = eqangle2 e d a b ? eqangle d a a e e c c d
+examples/complete2/unsolved2/complete_011_7_Book_00EE_04_E051-9.gex
+a b = segment a b; c = s_angle b a c 30; d = mirror d b c; e = foot e d a b ? cong d e a c
+examples/complete2/unsolved2/complete_003_6_GDD_FULL_21-40_27.gex
+a b c = triangle a b c; h = orthocenter h a b c; o = circumcenter o a b c; o3 = circumcenter o3 a h b; o1 = circumcenter o1 b h c; o2 = circumcenter o2 c h a ? cong h o1 h o2
+examples/complete2/unsolved2/complete_017_ex-gao_ex160_4_e08.gex
+a b c = triangle a b c; d = on_line d b c, angle_bisector d b a c; e = on_pline e d a c, on_line e a b; f = on_pline f e b c, on_line f a c ? cong e a f c
+examples/complete2/unsolved2/complete_015_7_Book_00EE_06_E051-29.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b; f = on_line f b e, on_line f c d; g = on_line g d e, on_pline g f b c; h = on_circle h a b, on_dia h g a ? cong f g g h
+examples/complete2/unsolved2/complete_002_6_GDD_FULL_41-60_42.gex
+a b c = triangle a b c; d = incenter d a b c; e = foot e a b c; f = foot f b a d; g = foot g c a d; h = midpoint h c b ? cyclic e f g h
+examples/complete2/unsolved2/complete_014_7_Book_00EE_08_E061-63f.gex
+a b = segment a b; c = midpoint c b a; d = s_angle b a d 30, on_circle d c a; e = lc_tangent e d c, on_line e a b ? cong d a d e
+examples/complete2/unsolved2/complete_007_7_Book_LLL_yL198-1.gex
+c d = segment c d; e = midpoint e c d; b = free b; f = midpoint f b c; a = eqdistance a d c b, on_pline a b d c; g = midpoint g a b; h = midpoint h d a ? cong h e e f
+examples/complete2/unsolved/complete_015_7_Book_00EE_08_E061-61.gex
+a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = lc_tangent e c a, lc_tangent e d a; f = lc_tangent f b a, on_line f c e; h = on_pline h c b f, on_line h b d; g = on_line g d e, on_line g b f; i = on_line i b e, on_line i c h ? cong c i i h
+examples/complete2/unsolved/complete_008_ex-gao_ex160_204.gex
+a b c = triangle a b c; d = circumcenter d a b c; e = on_circle e d a; f = on_line f a b, on_line f c e; h = on_line h b c, angle_bisector h a f c; g = on_line g a e, on_line g b c; i = on_line i a b, angle_bisector i a g b; j = on_line j f h, on_line j g i; k = on_line k a e, on_line k f h ? perp g i f h
+examples/complete2/unsolved/complete_005_Other_unsolved_65.gex
+a b c = triangle a b c; e = on_line e a b; f = on_pline f e b c, on_line f a c; d = circle d a b c; g = circle g a e f ? coll a g d
+examples/complete2/unsolved/complete_008_ex-gao_ex160_005.gex
+a b c = triangle a b c; d = on_line d b c, angle_bisector d b a c; e = on_line e a b; f = on_line f a c, eqdistance f c b e; g = midpoint g f e; h = midpoint h c b ? para g h a d
+examples/complete2/unsolved/complete_006_Other_ndgTest_65.gex
+a b c = triangle a b c; e = on_line e a b; f = intersection_lp f a c e c b; d = circle d a b c; g = circle g a e f ? coll a g d
+examples/complete2/unsolved/complete_005_Other_unsolved_E051-7.gex
+a b c = triangle a b c; d = on_line d a b; e = angle_bisector e c b a, on_line e a c; f = on_pline f e a b, on_line f b c; g = angle_bisector g c b d, on_line g e f ? cong e f f g
+examples/complete2/unsolved/complete_005_Other_unsolved_E046-10.gex
+a b c d = isquare a b c d; e = midpoint e b a; f = on_line f a b; g = on_tline g e d e, angle_bisector g c b f ? cong d e e g
+examples/complete2/unsolved/ex-gao_ex160_103.gex
+c b y = triangle c b y; a = foot a c b y; x = angle_bisector x c b y; d = foot d a b c; e = on_line e b x, on_line e c a; f = foot f e b c; g = on_line g b x, on_line g a d ? cong f g f e
+examples/complete2/unsolved/ex-gao_ex160_104.gex
+c a y = triangle c a y; b = foot b c a y; x = angle_bisector x c a y; e = foot e b a c; d = on_line d a x, on_line d c b; f = on_line f a x, on_line f b e; g = on_line g c b, on_pline g f a c ? cong b d c g
+examples/complete2/unsolved/complete_005_Other_unsolved_109f.gex
+a b d = triangle a b d; e = on_line e a d; c = on_line c a b; f = on_line f b e, on_line f c d; g = circle g d e f; h = circle h a c d; i = circle i b c f; j = circle j b a e ? cyclic i j h g
+examples/complete2/unsolved/complete_005_Other_unsolved_E046-7.gex
+a b c = triangle a b c; d = midpoint d c a; e = angle_bisector e b a d, on_line e b d; f = on_pline f b c e, on_line f a c ? cong b a c f
+examples/complete2/unsolved/complete_008_ex-gao_ex160_e121.gex
+a b = segment a b; d = midpoint d b a; c = on_tline c a a b; e = on_circle e c a; f = on_line f d e, on_circle f c a; g = on_circle g c f, on_line g f b; h = on_line h b e, on_circle h c a ? para a b g h
+examples/complete2/unsolved/complete_018_ex-gao_ex160_4_010.gex
+a b c d = isquare a b c d; e = mirror e a b; f = midpoint f b a; g = on_tline g f d f, angle_bisector g c b e; h = foot h g a b ? cong d f g f
+examples/complete2/unsolved/complete_005_Other_unsolved_82.gex
+a b c = triangle a b c; o = incenter o a b c; i = foot i c a o; e = on_tline e a a o; j = foot j c a e; l = foot l c b o ? coll i l j
+examples/complete2/unsolved/complete_014_7_Book_00EE_07_E057-41.gex
+a b c = triangle a b c; d = foot d c a b; e = free e; f = on_circle f c e; g = on_line g d f, on_circle g c e; h = on_line h d e, on_circle h c e; i = on_line i f h, on_line i a b; j = on_line j e g, on_line j a b ? cong j d d i
+examples/complete2/unsolved/complete_015_7_Book_00EE_08_E059-55.gex
+a b = segment a b; c = on_circle c a b; d = lc_tangent d b a, lc_tangent d c a; e = on_circle e a b; f = on_line f d e, on_circle f a b; g = angle_bisector g f b e, on_line g d e ? cong d b d g
+examples/complete2/unsolved/complete_005_Other_unsolved_E073-17.gex
+a b c = triangle a b c; o = circumcenter o a b c; p = lc_tangent p a o, on_line p b c; d = on_circle d p a; e = intersection_lc e d o b; f = intersection_lc f d o c ? para e f p d
+examples/complete2/unsolved/complete_015_7_Book_00EE_06_E056-33.gex
+a b = segment a b; c = on_tline c a a b, on_circle c a b; e = s_angle b a e 60, on_circle e a b; d = s_angle b a d 30, on_circle d a b; f = on_line f a e, on_line f b c; g = on_line g a d, on_line g b c ? cong c f g b
+examples/complete2/unsolved/complete_005_Other_unsolved_E074-24.gex
+a b = segment a b; c = on_tline c a a b; d = foot d a b c; e = midpoint e d a; f = on_line f b e, on_line f a c; g = foot g f b c; h = midpoint h c a; i = on_tline i f a c, on_circle i h a ? cong f i f g
+examples/complete2/unsolved1/complete_008_7_Book_LLL_L057-3.gex
+a b = segment a b; c = on_bline c a b; e = midpoint e a c; d = mirror d c a; f = midpoint f b d ? cong e b f b
+examples/complete2/unsolved1/complete_006_7_Book_LLL_L046-17.gex
+c a b = risos c a b; d = midpoint d b a; e = on_line e a b; f = foot f e a c; g = foot g e b c ? cong d f d g
+examples/complete2/unsolved1/complete_008_7_Book_LLL_L057-2.gex
+c a b = triangle a b c; d = angle_bisector d b a c; e = foot e c a d; f = intersection_lp f a c e a b ? cong f e f c
+examples/complete2/unsolved1/complete_008_ex-gao_ex160_e102.gex
+a b c = triangle a b c; d = on_line d a c, angle_bisector d c b a; e = on_line e a b, angle_bisector e b c a; f = foot f a c e; g = foot g a b d ? para g f b c
+examples/complete2/unsolved1/complete_012_7_Book_00EE_11_E075-27f.gex
+a b c = triangle a b c; d = foot d c a b; e = on_line e c d; f = on_line f b e, on_line f a c; g = on_line g a e, on_line g b c ? eqangle d f d c d c d g
+examples/complete2/unsolved1/complete_006_7_Book_LLL_L091-13.gex
+b c d = triangle b c d; e = midpoint e c d; a = eqdistance a d c b, on_pline a b d c; f = midpoint f b a ? perp a b e f
+examples/complete2/unsolved1/complete_010_Other_gao_Y_yL157-1.gex
+a b d = triangle a b d; e = midpoint e d a; c = angle_bisector c a b d, on_dia c a b ? para c e b d
+examples/complete2/unsolved1/complete_008_7_Book_LLL_L055-5.gex
+d b a = triangle d b a; c = angle_bisector c d a b, angle_bisector c d b a; e = on_line e b c, on_tline e d b c; f = on_line f a c, on_tline f d a c ? para e f a b
+examples/complete2/unsolved1/complete_013_7_Book_00EE_11_E075-29.gex
+a b = segment a b; d = midpoint d b a; f = on_circle f d a; c = intersection_lt c a b f d f; e = midpoint e c a; g = on_tline g b a b, on_circle g e a ? eqangle f c f g g f g c
+examples/complete2/unsolved1/complete_012_7_Book_00EE_05_E051-14-1.gex
+a b c d = quadrangle a b c d; e = midpoint e b a; f = foot f a c d; g = foot g b c d ? cong e f e g
+examples/complete2/unsolved1/complete_007_7_Book_LLL_L057-3-1.gex
+a b = segment a b; c = on_bline c a b; e = midpoint e a c; f = mirror f b e; d = on_circle d a c, on_line d a c ? cong f b b d
+examples/complete2/unsolved1/complete_001_6_GDD_FULL_61-80_80.gex
+a b c = triangle a b c; o = circle o a b c; u = angle_bisector u b a c, on_line u b c; t = on_tline t a a o, on_line t b c ? 
cong t a t u +examples/complete2/unsolved1/complete_008_ex-gao_ex160_e213.gex +a b c = triangle a b c; d = midpoint d a b; e = midpoint e c a; f = on_line f d e, angle_bisector f c b a ? perp a f b f +examples/complete2/unsolved1/complete_011_7_Book_00EE_04_E051-2.gex +a b c = triangle a b c; d = on_line d a b; e = on_pline e d b c, on_line e a c; f = on_line f c d, on_line f b e; g = on_line g a f, on_line g b c ? midp g b c +examples/complete2/unsolved1/complete_007_7_Book_LLL_L043-5.gex +a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b; f = intersection_ll f b c d e; g = intersection_ll g c d b e; h = angle_bisector h d f b; i = angle_bisector i c g b; j = intersection_ll j f h g i ? perp g i f h +examples/complete2/unsolved1/complete_001_6_GDD_FULL_61-80_71.gex +a b c = triangle a b c; o = circle o a b c; e = on_pline e a b c, on_circle e o a; f = foot f e a b; g = foot g e a c ? para f g a o +examples/complete2/unsolved1/complete_007_7_Book_LLL_L043-5-1.gex +a b = segment a b; c = on_circle c a b; d = on_circle d a b; e = on_circle e a b; f = intersection_ll f b c d e; g = intersection_ll g c d b e; h = angle_bisector h d f b; i = angle_bisector i c g b; j = intersection_ll j f h g i; k = intersection_ll k d e g i; l = intersection_ll l g i b c ? perp g i f h +examples/complete2/unsolved1/complete_011_7_Book_00EE_04_E051-8.gex +a b c = triangle a b c; d = angle_bisector d a c b, on_line d a b; e = on_pline e d b c, on_line e a c; f = on_pline f e a b, on_line f b c ? cong c e f b +examples/complete2/unsolved1/complete_013_7_Book_00EE_10_E072-8.gex +a b c = triangle a b c; d = angle_bisector d b a c, on_line d b c; f = on_line f b c, on_bline f a d; e = on_bline e a d, on_line e a d ? eqangle a b a f c f c a +examples/complete2/unsolved1/complete_003_6_GDD_FULL_more_E023-14.gex +a b c = triangle a b c; d = midpoint d b a; e = angle_bisector e c d a, on_line e a c; f = angle_bisector f c d b, on_line f c b ? 
para e f a b +examples/complete2/unsolved1/complete_010_Other_gao_Y_yL182-1.gex +a c d = triangle a c d; b = on_pline b c d a, on_pline b a d c; e = on_line e a c; f = shift f c a e ? para d e f b +examples/complete2/unsolved1/complete_008_ex-gao_ex160_e120.gex +a b = segment a b; c = midpoint c b a; d = on_circle d c a; e = on_tline e a a b, on_tline e d c d; f = on_tline f b a b, on_line f d e; g = on_line g b e, on_line g a f ? para d g a e +new_unsolved/0.gex +c d b = triangle c d b; e = midpoint e c d; a = eqdistance a d c b, on_pline a b d c; f = midpoint f b a ? perp a b e f +new_unsolved/1.gex +a b c d = eq_trapezoid a b c d ? eqangle a d a b b a b c +examples/complete2/unsolved1/complete_009_Other_paper_Thebault_t5.gex +a b c = triangle a b c; d = circle d a b c; e = on_line e b c; f g h i = 2l1c f g h i a b e d; j k l m = 2l1c j k l m a c e d; n = incenter n a b c ? coll m n i +examples/complete2/unsolved/complete_013_7_Book_00EE_10_E072-11.gex +a b = segment b a; c = on_line c a b; d = on_circle d a b; e = on_circle e a b; g = on_line g d e, on_circle g c b; f = on_line f d e, on_circle f c b ? eqangle b e b f b g b d +examples/complete2/unsolved2/complete_015_7_Book_00EE_06_E051-28.gex +b c = segment b c; a = on_tline a b b c; d = on_circle d c b; e g = e5128 e g a b c d ? cong a g g b +examples/complete2/unsolved2/complete_010_Other_Auxiliary_ye_aux_think.gex +c a b = iso_triangle c a b; d e f = 3peq d e f c a b ? cong d a b e +examples/complete2/unsolved/morley.gex +a b c = triangle a b c; d e = trisect d e b a c; f g = trisect f g c b a; h i = trisect h i a c b; j = intersection_ll j b f c i; k = intersection_ll k a e c h; l = intersection_ll l a d b g ? 
cong j l j k diff --git a/lm_inference.py b/lm_inference.py new file mode 100644 index 0000000..d404b0b --- /dev/null +++ b/lm_inference.py @@ -0,0 +1,189 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Wrapper for language modeling inference implemented in Meliad.""" +from typing import Any, Dict + +import jax +import models # pylint: disable=unused-import +import t5.data +from transformer import inference_utils + + +np = jax.numpy + + +Trainer = inference_utils.Trainer + +MetricsOutput = Dict[str, Any] # Metrics output by model. + + +parse_gin_configuration = inference_utils.parse_gin_configuration + + +class LanguageModelInference: + """Meliad wrapper for LM inference.""" + + def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'): + self.vocab = t5.data.SentencePieceVocabulary(vocab_path) + + # This task won't be pulling from a dataset. + def null_iter_fn() -> None: + return None + + process_summaries_f = inference_utils.models.process_summaries_function( + self.vocab + ) + + trainer = inference_utils.training_loop.Trainer( + get_training_dataset_iterator=null_iter_fn, + get_test_dataset_iterator=None, + pretty_print_input_function=None, + process_summaries_function=process_summaries_f, + load_dir=load_dir, + workdir='', # Don't log or save checkpoints. 
+        replicate_mode=False,
+    )  # Run on a single device at batch size 1.
+    self.trainer = trainer
+
+    # Create and initialize the model.
+    (tstate, _, imodel, prngs) = trainer.initialize_model()
+    self.imodel = imodel
+    self.batch_size = imodel.task_config.batch_size
+
+    self.n = imodel.num_heads
+    self.h = imodel.head_size
+
+    # Create an inference task.
+    writers = {}
+    self.task = trainer.create_training_task(mode, imodel, prngs, writers)  # pylint: disable=too-many-function-args
+
+    # Register any additional actions.
+    # Actions are cleared first for use with colab.
+    inference_utils.training_loop.clear_interstep_callbacks()
+    inference_utils.training_loop.register_interstep_callbacks()
+    self.tstate = tstate
+
+    # some default parameters.
+    eos = [0] * 1024
+    for idx in self.encode_list(['.', ';']):
+      eos[idx] = 1
+
+    self.eos = np.array(eos, dtype=np.bfloat16)
+    self.mask = jax.numpy.ones([1024], dtype=np.bfloat16)
+
+  def decode(self, ids: list[int]) -> str:
+    return self.vocab.decode(ids)
+
+  def decode_list(self, tokens: list[int]) -> list[str]:
+    return [self.decode([tok]) for tok in tokens]
+
+  def encode(self, inputs_str: str) -> list[int]:
+    return self.vocab.encode(inputs_str)
+
+  def encode_list(self, inputs_strs: list[str]) -> list[int]:
+    result = [self.vocab.encode(x) for x in inputs_strs]
+    assert all([len(x) == 1 for x in result]), [
+        self.decode(x) for x in result if len(x) != 1
+    ]
+    return [x[0] for x in result]
+
+  def call(
+      self,
+      inputs: np.ndarray,
+      dstate: tuple[dict[str, np.ndarray], ...] = None,
+      eos: np.ndarray = None,
+      mask: np.ndarray = None,
+  ) -> MetricsOutput:
+    """Call the meliad model."""
+    batch_size, length = inputs.shape
+    inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])
+
+    if eos is None:
+      eos = self.eos
+    if mask is None:
+      mask = self.mask
+
+    x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}
+
+    if dstate is not None:
+      x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
+    else:
+      dstate = tuple(
+          [{  # this dummy value will never be used.
+              'current_index': np.array([0] * batch_size, dtype=np.int32),
+              'keys': np.zeros(
+                  (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
+              ),
+              'values': np.zeros(
+                  (batch_size, 2048, self.n, self.h), dtype=np.bfloat16
+              ),
+              'recurrent_kvq': None,
+              'relative_position_bias': np.zeros(
+                  (batch_size, self.n, 1, 1024), dtype=np.bfloat16
+              ),
+          }]
+          * 12
+      )
+      x['start_of_sequence'] = jax.numpy.array([True] * batch_size)
+
+    x['dstate'] = dstate
+    _, metrics_np = self.task.run_step(self.tstate, x, 0)
+    return metrics_np
+
+  def beam_decode(
+      self,
+      inputs: str,
+      eos_tokens: np.ndarray = None,
+      mask_tokens: np.ndarray = None,
+      dstate: dict[str, np.ndarray] = None,
+  ) -> MetricsOutput:
+    """Beam search."""
+    inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)
+
+    eos = self.eos
+    if eos_tokens is not None:
+      eos_ids = self.encode_list(eos_tokens)
+      eos = np.array(
+          [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.bfloat16
+      ).reshape((1, 1, 1024))
+
+    mask = self.mask
+    if mask_tokens is not None:
+      mask_ids = self.encode_list(mask_tokens)
+      mask = np.array(
+          [0 if idx in mask_ids else 1 for idx in range(1024)],
+          dtype=np.bfloat16,
+      ).reshape((1, 1, 1024))
+
+    metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)
+
+    finished_seqs = metrics_np['finished_seqs']
+    finished_scores = metrics_np['finished_scores']
+
+    seqs = []
+    scores = []
+    for seq, score in zip(finished_seqs, finished_scores):
+      seq = self.decode(seq[1:])
+      seqs.append(seq)
+      scores.append(score)
+
+    return {
+        'finished_seqs': finished_seqs,
+        'finished_scores': finished_scores,
+        'seqs_str': seqs,
+        'scores': scores,
+        'dstate': metrics_np['dstate'],
+    }
diff --git a/lm_inference_test.py b/lm_inference_test.py
new file mode 100644
index 0000000..76571ce
--- /dev/null
+++ b/lm_inference_test.py
@@ -0,0 +1,89 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Unit tests for lm_inference.py."""
+import os
+import unittest
+
+from absl import flags
+from absl.testing import absltest
+import lm_inference as lm
+
+
+_DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.')
+_MELIAD_PATH = flags.DEFINE_string(
+    'meliad_path', '', 'path to meliad repository.'
+)  # pylint: disable=line-too-long
+
+
+class LmInferenceTest(unittest.TestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    gin_file = [
+        'base_htrans.gin',
+        'size/medium_150M.gin',
+        'options/positions_t5.gin',
+        'options/lr_cosine_decay.gin',
+        'options/seq_1024_nocache.gin',
+        'geometry_150M_generate.gin',
+    ]
+
+    gin_param = [
+        'DecoderOnlyLanguageModelGenerate.output_token_losses=True',
+        'TransformerTaskConfig.batch_size=2',
+        'TransformerTaskConfig.sequence_length=128',
+        'Trainer.restore_state_variables=False',
+    ]
+
+    gin_search_paths = [
+        os.path.join(_MELIAD_PATH.value, 'transformer/configs'),
+        os.getcwd(),
+    ]
+
+    vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model')
+
+    lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths)
+
+    cls.loaded_lm = lm.LanguageModelInference(
+        vocab_path, _DATA_PATH.value, mode='beam_search'
+    )
+
+  def test_lm_decode(self):
+    outputs = LmInferenceTest.loaded_lm.beam_decode(
+        '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
+        ' {F1} x00',
+        eos_tokens=[';'],
+    )
+    self.assertEqual(
+        outputs['seqs_str'],
+        ['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'],
+    )
+
+  def test_lm_score_may_fail_numerically_for_external_meliad(self):
+    outputs = LmInferenceTest.loaded_lm.beam_decode(
+        '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
+        ' {F1} x00',
+        eos_tokens=[';'],
+    )
+    self.assertEqual(
+        outputs['scores'],
+        [-1.18607294559478759765625, -1.10228693485260009765625],
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..0a994c9
--- /dev/null
+++ b/models.py
@@ -0,0 +1,178 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Transformer language model generate mode."""
+
+from typing import Any, Tuple
+import beam_search
+import decoder_stack
+import gin
+import jax
+import jax.numpy as jnp
+from transformer import models
+
+
+@gin.configurable
+class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
+  """Decoder only language modeling in inference mode."""
+
+  decoder_factory = decoder_stack.DecoderStackGenerate
+
+  num_heads: int = gin.REQUIRED
+  head_size: int = gin.REQUIRED
+
+  def get_fake_input(self) -> dict[str, Any]:
+    fake_input_dict = super().get_fake_input()
+    b = self.task_config.batch_size
+    n = self.num_heads
+    h = self.head_size
+    fake_input_dict.update({
+        'dstate': tuple(
+            [{
+                'current_index': jnp.array([0] * b, dtype=jnp.int32),
+                'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
+                'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
+                'recurrent_kvq': None,
+                'relative_position_bias': jnp.zeros(
+                    (b, n, 1, 1024), dtype=jnp.bfloat16
+                ),
+            }]
+            * 12
+        ),
+        'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
+        'mask': jnp.ones([1024], dtype=jnp.bfloat16),
+        'length': 1,
+        'temperature': 1.0,
+    })
+    return fake_input_dict
+
+  def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
+    # Make sure this code is not used on untested cases.
+    if self.mode not in ['init', 'beam_search']:
+      raise ValueError(f'{type(self)} cannot do mode {self.mode}')
+    if self.decoder.supports_generate():
+      raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
+
+    self.decoder(
+        input_tokens=inputs['targets'][:, 0:1],
+        target_tokens=None,
+        start_of_sequence=inputs['start_of_sequence'],
+    )
+
+    b = inputs['targets'].shape[0]
+    no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
+
+    # This fn is used in both beam_search or topk_sampling.
+    def tokens_to_logits_fn(
+        input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
+    ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
+      (logits, dstate, _) = self.decoder(
+          input_tokens=input_token,
+          target_tokens=None,
+          start_of_sequence=no_start_of_seq,
+          decoder_state=dstate,
+      )
+      return logits[:, -1, :], dstate
+
+    last_token = jax.lax.dynamic_slice_in_dim(
+        inputs['targets'], inputs['length'] - 1, 1, axis=1
+    )
+
+    # last token is used to seed beam_search
+    inputs['targets'] = inputs['targets'][:, 0:-1]
+    dstate = jax.lax.cond(
+        inputs['start_of_sequence'][0],
+        lambda: self.generate(inputs)[0],
+        lambda: inputs['dstate'],
+    )
+
+    # Then we run beam search, init with last_token & dstate.
+    finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
+        last_token,
+        dstate,
+        tokens_to_logits_fn,
+        max_decode_len=512,
+        eos=inputs['eos'].reshape((1, 1, -1)),
+        mask=inputs['mask'].reshape((1, 1, -1)),
+    )
+
+    return 0.0, {
+        'finished_seqs': finished_seqs,
+        'finished_scores': finished_scores,
+        'dstate': dstate,
+    }
+
+  def generate(
+      self, inputs: ...
+  ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
+    """Generate an output sequence.
+
+    Args:
+      inputs: the same as argument to _call_.
+
+    Returns:
+      An array of generated tokens of shape (batch_size, sequence_length).
+    """
+    input_tokens = inputs['targets']  # [b,seq_len]
+    start_of_sequence = inputs['start_of_sequence']  # [b]
+    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
+    batch_size = target_tokens.shape[0]
+
+    # Assuming all sequences start at the same time.
+    start0 = inputs['start_of_sequence'][0]
+    dstate = jax.lax.cond(
+        start0,
+        lambda: self.decoder.init_decoder_state_vanilla(  # pylint: disable=g-long-lambda
+            1024, start_of_sequence
+        ),
+        lambda: inputs['dstate'],
+    )
+
+    first_token = input_tokens[:, 0:1]
+    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
+    temperature = 1
+    if 'temperature' in inputs:
+      temperature = inputs['temperature']
+
+    num_steps = inputs['length']
+    if self.mode == 'beam_search':
+      num_steps -= 1
+
+    def cond_fn(scan_state) -> jnp.bool_:
+      _, _, i, _ = scan_state
+      return i < num_steps
+
+    def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
+      (dstate, input_token, i, _) = scan_state
+
+      (logits, dstate, _) = self.decoder(
+          input_tokens=input_token,
+          target_tokens=None,
+          start_of_sequence=no_start_of_seq,
+          decoder_state=dstate,
+      )
+
+      logits = logits / temperature
+      output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
+
+      return (dstate, output_token, i + 1, logits)
+
+    # Scan over the sequence length.
+    dummy_logits = jnp.zeros((batch_size, 1, 1024))
+    initial_scan_state = (dstate, first_token, 0, dummy_logits)
+    dstate, _, _, logits = jax.lax.while_loop(
+        cond_fn, loop_fn, initial_scan_state
+    )
+    return dstate, logits
diff --git a/numericals.py b/numericals.py
new file mode 100644
index 0000000..bbe03fe
--- /dev/null
+++ b/numericals.py
@@ -0,0 +1,1921 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Numerical representation of geometry.""" +from __future__ import annotations + +import math +from typing import Any, Optional, Union + +import geometry as gm +import matplotlib +from matplotlib import pyplot as plt +import matplotlib.colors as mcolors +import numpy as np +from numpy.random import uniform as unif # pylint: disable=g-importing-member + + +matplotlib.use('TkAgg') + + +ATOM = 1e-12 + + +# Some variables are there for better code reading. +# pylint: disable=unused-assignment +# pylint: disable=unused-argument +# pylint: disable=unused-variable + +# Naming in geometry is a little different +# we stick to geometry naming to better read the code. 
+# pylint: disable=invalid-name + + +class Point: + """Numerical point.""" + + def __init__(self, x, y): + self.x = x + self.y = y + + def __lt__(self, other: Point) -> bool: + return (self.x, self.y) < (other.x, other.y) + + def __gt__(self, other: Point) -> bool: + return (self.x, self.y) > (other.x, other.y) + + def __add__(self, p: Point) -> Point: + return Point(self.x + p.x, self.y + p.y) + + def __sub__(self, p: Point) -> Point: + return Point(self.x - p.x, self.y - p.y) + + def __mul__(self, f: float) -> Point: + return Point(self.x * f, self.y * f) + + def __rmul__(self, f: float) -> Point: + return self * f + + def __truediv__(self, f: float) -> Point: + return Point(self.x / f, self.y / f) + + def __floordiv__(self, f: float) -> Point: + div = self / f # true div + return Point(int(div.x), int(div.y)) + + def __str__(self) -> str: + return 'P({},{})'.format(self.x, self.y) + + def close(self, point: Point, tol: float = 1e-12) -> bool: + return abs(self.x - point.x) < tol and abs(self.y - point.y) < tol + + def midpoint(self, p: Point) -> Point: + return Point(0.5 * (self.x + p.x), 0.5 * (self.y + p.y)) + + def distance(self, p: Union[Point, Line, Circle]) -> float: + if isinstance(p, Line): + return p.distance(self) + if isinstance(p, Circle): + return abs(p.radius - self.distance(p.center)) + dx = self.x - p.x + dy = self.y - p.y + return np.sqrt(dx * dx + dy * dy) + + def distance2(self, p: Point) -> float: + if isinstance(p, Line): + return p.distance(self) + dx = self.x - p.x + dy = self.y - p.y + return dx * dx + dy * dy + + def rotatea(self, ang: float) -> Point: + sinb, cosb = np.sin(ang), np.cos(ang) + return self.rotate(sinb, cosb) + + def rotate(self, sinb: float, cosb: float) -> Point: + x, y = self.x, self.y + return Point(x * cosb - y * sinb, x * sinb + y * cosb) + + def flip(self) -> Point: + return Point(-self.x, self.y) + + def perpendicular_line(self, line: Line) -> Line: + return line.perpendicular_line(self) + + def foot(self, line: 
Line) -> Point: + if isinstance(line, Line): + l = line.perpendicular_line(self) + return line_line_intersection(l, line) + elif isinstance(line, Circle): + c, r = line.center, line.radius + return c + (self - c) * r / self.distance(c) + raise ValueError('Dropping foot to weird type {}'.format(type(line))) + + def parallel_line(self, line: Line) -> Line: + return line.parallel_line(self) + + def norm(self) -> float: + return np.sqrt(self.x**2 + self.y**2) + + def cos(self, other: Point) -> float: + x, y = self.x, self.y + a, b = other.x, other.y + return (x * a + y * b) / self.norm() / other.norm() + + def dot(self, other: Point) -> float: + return self.x * other.x + self.y * other.y + + def sign(self, line: Line) -> int: + return line.sign(self) + + def is_same(self, other: Point) -> bool: + return self.distance(other) <= ATOM + + +class Line: + """Numerical line.""" + + def __init__( + self, + p1: Point = None, + p2: Point = None, + coefficients: tuple[int, int, int] = None, + ): + if p1 is None and p2 is None and coefficients is None: + self.coefficients = None, None, None + return + + a, b, c = coefficients or ( + p1.y - p2.y, + p2.x - p1.x, + p1.x * p2.y - p2.x * p1.y, + ) + + # Make sure a is always positive (or always negative for that matter) + # With a == 0, Assuming a = +epsilon > 0 + # Then b such that ax + by = 0 with y>0 should be negative. 
+ if a < 0.0 or a == 0.0 and b > 0.0: + a, b, c = -a, -b, -c + + self.coefficients = a, b, c + + def parallel_line(self, p: Point) -> Line: + a, b, _ = self.coefficients + return Line(coefficients=(a, b, -a * p.x - b * p.y)) # pylint: disable=invalid-unary-operand-type + + def perpendicular_line(self, p: Point) -> Line: + a, b, _ = self.coefficients + return Line(p, p + Point(a, b)) + + def greater_than(self, other: Line) -> bool: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + # b/a > y/x + return b * x > a * y + + def __gt__(self, other: Line) -> bool: + return self.greater_than(other) + + def __lt__(self, other: Line) -> bool: + return other.greater_than(self) + + def same(self, other: Line) -> bool: + a, b, c = self.coefficients + x, y, z = other.coefficients + return close_enough(a * y, b * x) and close_enough(b * z, c * y) + + def equal(self, other: Line) -> bool: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + # b/a == y/x + return b * x == a * y + + def less_than(self, other: Line) -> bool: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + # b/a > y/x + return b * x < a * y + + def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]: + if isinstance(obj, Line): + return line_line_intersection(self, obj) + if isinstance(obj, Circle): + return line_circle_intersection(self, obj) + + def distance(self, p: Point) -> float: + a, b, c = self.coefficients + return abs(self(p.x, p.y)) / math.sqrt(a * a + b * b) + + def __call__(self, x: Point, y: Point = None) -> float: + if isinstance(x, Point) and y is None: + return self(x.x, x.y) + a, b, c = self.coefficients + return x * a + y * b + c + + def is_parallel(self, other: Line) -> bool: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + return abs(a * y - b * x) < ATOM + + def is_perp(self, other: Line) -> bool: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + return abs(a * x + b * y) < ATOM + + def cross(self, other: Line) -> float: + 
a, b, _ = self.coefficients + x, y, _ = other.coefficients + return a * y - b * x + + def dot(self, other: Line) -> float: + a, b, _ = self.coefficients + x, y, _ = other.coefficients + return a * x + b * y + + def point_at(self, x: float = None, y: float = None) -> Optional[Point]: + """Get a point on line closest to (x, y).""" + a, b, c = self.coefficients + # ax + by + c = 0 + if x is None and y is not None: + if a != 0: + return Point((-c - b * y) / a, y) # pylint: disable=invalid-unary-operand-type + else: + return None + elif x is not None and y is None: + if b != 0: + return Point(x, (-c - a * x) / b) # pylint: disable=invalid-unary-operand-type + else: + return None + elif x is not None and y is not None: + if a * x + b * y + c == 0.0: + return Point(x, y) + return None + + def diff_side(self, p1: Point, p2: Point) -> Optional[bool]: + d1 = self(p1.x, p1.y) + d2 = self(p2.x, p2.y) + if d1 == 0 or d2 == 0: + return None + return d1 * d2 < 0 + + def same_side(self, p1: Point, p2: Point) -> Optional[bool]: + d1 = self(p1.x, p1.y) + d2 = self(p2.x, p2.y) + if d1 == 0 or d2 == 0: + return None + return d1 * d2 > 0 + + def sign(self, point: Point) -> int: + s = self(point.x, point.y) + if s > 0: + return 1 + elif s < 0: + return -1 + return 0 + + def is_same(self, other: Line) -> bool: + a, b, c = self.coefficients + x, y, z = other.coefficients + return abs(a * y - b * x) <= ATOM and abs(b * z - c * y) <= ATOM + + def sample_within(self, points: list[Point], n: int = 5) -> list[Point]: + """Sample a point within the boundary of points.""" + center = sum(points, Point(0.0, 0.0)) * (1.0 / len(points)) + radius = max([p.distance(center) for p in points]) + if close_enough(center.distance(self), radius): + center = center.foot(self) + a, b = line_circle_intersection(self, Circle(center.foot(self), radius)) + + result = None + best = -1.0 + for _ in range(n): + rand = unif(0.0, 1.0) + x = a + (b - a) * rand + mind = min([x.distance(p) for p in points]) + if mind > 
best: + best = mind + result = x + + return [result] + + +class InvalidLineIntersectError(Exception): + pass + + +class HalfLine(Line): + """Numerical ray.""" + + def __init__(self, tail: Point, head: Point): # pylint: disable=super-init-not-called + self.line = Line(tail, head) + self.coefficients = self.line.coefficients + self.tail = tail + self.head = head + + def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point: + if isinstance(obj, (HalfLine, Line)): + return line_line_intersection(self.line, obj) + + exclude = [self.tail] + if isinstance(obj, HoleCircle): + exclude += [obj.hole] + + a, b = line_circle_intersection(self.line, obj) + if any([a.close(x) for x in exclude]): + return b + if any([b.close(x) for x in exclude]): + return a + + v = self.head - self.tail + va = a - self.tail + vb = b - self.tail + if v.dot(va) > 0: + return a + if v.dot(vb) > 0: + return b + raise InvalidLineIntersectError() + + def sample_within(self, points: list[Point], n: int = 5) -> list[Point]: + center = sum(points, Point(0.0, 0.0)) * (1.0 / len(points)) + radius = max([p.distance(center) for p in points]) + if close_enough(center.distance(self.line), radius): + center = center.foot(self) + a, b = line_circle_intersection(self, Circle(center.foot(self), radius)) + + if (a - self.tail).dot(self.head - self.tail) > 0: + a, b = self.tail, a + else: + a, b = self.tail, b # pylint: disable=self-assigning-variable + + result = None + best = -1.0 + for _ in range(n): + x = a + (b - a) * unif(0.0, 1.0) + mind = min([x.distance(p) for p in points]) + if mind > best: + best = mind + result = x + + return [result] + + +def _perpendicular_bisector(p1: Point, p2: Point) -> Line: + midpoint = (p1 + p2) * 0.5 + return Line(midpoint, midpoint + Point(p2.y - p1.y, p1.x - p2.x)) + + +def same_sign( + a: Point, b: Point, c: Point, d: Point, e: Point, f: Point +) -> bool: + a, b, c, d, e, f = map(lambda p: p.sym, [a, b, c, d, e, f]) + ab, cb = a - b, c - b + de, fe = d - e, 
f - e + return (ab.x * cb.y - ab.y * cb.x) * (de.x * fe.y - de.y * fe.x) > 0 + + +class Circle: + """Numerical circle.""" + + def __init__( + self, + center: Optional[Point] = None, + radius: Optional[float] = None, + p1: Optional[Point] = None, + p2: Optional[Point] = None, + p3: Optional[Point] = None, + ): + if not center: + if not (p1 and p2 and p3): + self.center = self.radius = self.r2 = None + return + # raise ValueError('Circle without center need p1 p2 p3') + + l12 = _perpendicular_bisector(p1, p2) + l23 = _perpendicular_bisector(p2, p3) + center = line_line_intersection(l12, l23) + + self.center = center + self.a, self.b = center.x, center.y + + if not radius: + if not (p1 or p2 or p3): + raise ValueError('Circle needs radius or p1 or p2 or p3') + p = p1 or p2 or p3 + self.r2 = (self.a - p.x) ** 2 + (self.b - p.y) ** 2 + self.radius = math.sqrt(self.r2) + else: + self.radius = radius + self.r2 = radius * radius + + def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]: + if isinstance(obj, Line): + return obj.intersect(self) + if isinstance(obj, Circle): + return circle_circle_intersection(self, obj) + + def sample_within(self, points: list[Point], n: int = 5) -> list[Point]: + """Sample a point within the boundary of points.""" + result = None + best = -1.0 + for _ in range(n): + ang = unif(0.0, 2.0) * np.pi + x = self.center + Point(np.cos(ang), np.sin(ang)) * self.radius + mind = min([x.distance(p) for p in points]) + if mind > best: + best = mind + result = x + + return [result] + + +class HoleCircle(Circle): + """Numerical circle with a missing point.""" + + def __init__(self, center: Point, radius: float, hole: Point): + super().__init__(center, radius) + self.hole = hole + + def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point: + if isinstance(obj, Line): + a, b = line_circle_intersection(obj, self) + if a.close(self.hole): + return b + return a + if isinstance(obj, HalfLine): + return obj.intersect(self) + if 
isinstance(obj, HoleCircle):
+      # Check HoleCircle before Circle: HoleCircle is a Circle subclass,
+      # so the Circle branch would otherwise shadow this one.
+      a, b = circle_circle_intersection(obj, self)
+      if a.close(self.hole) or a.close(obj.hole):
+        return b
+      return a
+    if isinstance(obj, Circle):
+      a, b = circle_circle_intersection(obj, self)
+      if a.close(self.hole):
+        return b
+      return a
+
+
+def solve_quad(a: float, b: float, c: float) -> Optional[tuple[float, float]]:
+  """Solve a x^2 + bx + c = 0."""
+  a = 2 * a
+  d = b * b - 2 * a * c
+  if d < 0:
+    return None  # the caller should expect this result.
+
+  y = math.sqrt(d)
+  return (-b - y) / a, (-b + y) / a
+
+
+def circle_circle_intersection(c1: Circle, c2: Circle) -> tuple[Point, Point]:
+  """Returns a pair of Points as intersections of c1 and c2."""
+  # circle 1: (x0, y0), radius r0
+  # circle 2: (x1, y1), radius r1
+  x0, y0, r0 = c1.a, c1.b, c1.radius
+  x1, y1, r1 = c2.a, c2.b, c2.radius
+
+  d = math.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
+  if d == 0:
+    raise InvalidQuadSolveError()
+
+  a = (r0**2 - r1**2 + d**2) / (2 * d)
+  h = r0**2 - a**2
+  if h < 0:
+    raise InvalidQuadSolveError()
+  h = np.sqrt(h)
+  x2 = x0 + a * (x1 - x0) / d
+  y2 = y0 + a * (y1 - y0) / d
+  x3 = x2 + h * (y1 - y0) / d
+  y3 = y2 - h * (x1 - x0) / d
+  x4 = x2 - h * (y1 - y0) / d
+  y4 = y2 + h * (x1 - x0) / d
+
+  return Point(x3, y3), Point(x4, y4)
+
+
+class InvalidQuadSolveError(Exception):
+  pass
+
+
+def line_circle_intersection(line: Line, circle: Circle) -> tuple[Point, Point]:
+  """Returns a pair of points as intersections of line and circle."""
+  a, b, c = line.coefficients
+  r = float(circle.radius)
+  center = circle.center
+  p, q = center.x, center.y
+
+  if b == 0:
+    x = -c / a
+    x_p = x - p
+    x_p2 = x_p * x_p
+    y = solve_quad(1, -2 * q, q * q + x_p2 - r * r)
+    if y is None:
+      raise InvalidQuadSolveError()
+    y1, y2 = y
+    return (Point(x, y1), Point(x, y2))
+
+  if a == 0:
+    y = -c / b
+    y_q = y - q
+    y_q2 = y_q * y_q
+    x = solve_quad(1, -2 * p, p * p + y_q2 - r * r)
+    if x is None:
+      raise InvalidQuadSolveError()
+    x1, x2 = x
+    return (Point(x1, y), Point(x2, y))
+
+  c_ap = c + a * p
+  a2 = a * a
+  y = solve_quad(
+      a2 + b * b, 2 * (b * c_ap - a2 * q), c_ap * c_ap + a2 * (q * q - r * r)
+  )
+  if y is None:
+    raise InvalidQuadSolveError()
+  y1, y2 = y
+
+  return Point(-(b * y1 + c) / a, y1), Point(-(b * y2 + c) / a, y2)
+
+
+def _check_between(a: Point, b: Point, c: Point) -> bool:
+  """Whether a is between b & c."""
+  return (a - b).dot(c - b) > 0 and (a - c).dot(b - c) > 0
+
+
+def circle_segment_intersect(
+    circle: Circle, p1: Point, p2: Point
+) -> list[Point]:
+  l = Line(p1, p2)
+  px, py = line_circle_intersection(l, circle)
+
+  result = []
+  if _check_between(px, p1, p2):
+    result.append(px)
+  if _check_between(py, p1, p2):
+    result.append(py)
+  return result
+
+
+def line_segment_intersection(l: Line, A: Point, B: Point) -> Point:  # pylint: disable=invalid-name
+  a, b, c = l.coefficients
+  x1, y1, x2, y2 = A.x, A.y, B.x, B.y
+  dx, dy = x2 - x1, y2 - y1
+  alpha = (-c - a * x1 - b * y1) / (a * dx + b * dy)
+  return Point(x1 + alpha * dx, y1 + alpha * dy)
+
+
+def line_line_intersection(l1: Line, l2: Line) -> Point:
+  a1, b1, c1 = l1.coefficients
+  a2, b2, c2 = l2.coefficients
+  # a1x + b1y + c1 = 0
+  # a2x + b2y + c2 = 0
+  d = a1 * b2 - a2 * b1
+  if d == 0:
+    raise InvalidLineIntersectError
+  return Point((c2 * b1 - c1 * b2) / d, (c1 * a2 - c2 * a1) / d)
+
+
+def check_too_close(
+    newpoints: list[Point], points: list[Point], tol: float = 0.1
+) -> bool:
+  if not points:
+    return False
+  avg = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
+  mindist = min([p.distance(avg) for p in points])
+  for p0 in newpoints:
+    for p1 in points:
+      if p0.distance(p1) < tol * mindist:
+        return True
+  return False
+
+
+def check_too_far(
+    newpoints: list[Point], points: list[Point], tol: float = 4.0
+) -> bool:
+  if len(points) < 2:
+    return False
+  avg = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
+  maxdist = max([p.distance(avg) for p in points])
+  for p in newpoints:
+    if p.distance(avg) > maxdist * tol:
+      return True
+  return 
False + + +def check_aconst(args: list[Point]) -> bool: + a, b, c, d, num, den = args + d = d + a - c + ang = ang_between(a, b, d) + if ang < 0: + ang += np.pi + return close_enough(ang, num * np.pi / den) + + +def check(name: str, args: list[Union[gm.Point, Point]]) -> bool: + """Numerical check.""" + if name == 'eqangle6': + name = 'eqangle' + elif name == 'eqratio6': + name = 'eqratio' + elif name in ['simtri2', 'simtri*']: + name = 'simtri' + elif name in ['contri2', 'contri*']: + name = 'contri' + elif name == 'para': + name = 'para_or_coll' + elif name == 'on_line': + name = 'coll' + elif name in ['rcompute', 'acompute']: + return True + elif name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']: + return True + + fn_name = 'check_' + name + if fn_name not in globals(): + return None + + fun = globals()['check_' + name] + args = [p.num if isinstance(p, gm.Point) else p for p in args] + return fun(args) + + +def check_circle(points: list[Point]) -> bool: + if len(points) != 4: + return False + o, a, b, c = points + oa, ob, oc = o.distance(a), o.distance(b), o.distance(c) + return close_enough(oa, ob) and close_enough(ob, oc) + + +def check_coll(points: list[Point]) -> bool: + a, b = points[:2] + l = Line(a, b) + for p in points[2:]: + if abs(l(p.x, p.y)) > ATOM: + return False + return True + + +def check_ncoll(points: list[Point]) -> bool: + return not check_coll(points) + + +def check_sameside(points: list[Point]) -> bool: + b, a, c, y, x, z = points + # whether b is to the same side of a & c as y is to x & z + ba = b - a + bc = b - c + yx = y - x + yz = y - z + return ba.dot(bc) * yx.dot(yz) > 0 + + +def check_para_or_coll(points: list[Point]) -> bool: + return check_para(points) or check_coll(points) + + +def check_para(points: list[Point]) -> bool: + a, b, c, d = points + ab = Line(a, b) + cd = Line(c, d) + if ab.same(cd): + return False + return ab.is_parallel(cd) + + +def check_perp(points: list[Point]) -> bool: + a, b, c, d = points + ab = Line(a, b) + cd = 
Line(c, d)
+  return ab.is_perp(cd)
+
+
+def check_cyclic(points: list[Point]) -> bool:
+  points = list(set(points))
+  a, b, c, *ps = points
+  circle = Circle(p1=a, p2=b, p3=c)
+  for d in ps:
+    if not close_enough(d.distance(circle.center), circle.radius):
+      return False
+  return True
+
+
+def bring_together(
+    a: Point, b: Point, c: Point, d: Point
+) -> tuple[Point, Point, Point, Point]:
+  ab = Line(a, b)
+  cd = Line(c, d)
+  x = line_line_intersection(ab, cd)
+  unit = Circle(center=x, radius=1.0)
+  y, _ = line_circle_intersection(ab, unit)
+  z, _ = line_circle_intersection(cd, unit)
+  return x, y, x, z
+
+
+def same_clock(
+    a: Point, b: Point, c: Point, d: Point, e: Point, f: Point
+) -> bool:
+  ba = b - a
+  cb = c - b
+  ed = e - d
+  fe = f - e
+  return (ba.x * cb.y - ba.y * cb.x) * (ed.x * fe.y - ed.y * fe.x) > 0
+
+
+def check_const_angle(points: list[Point]) -> bool:
+  """Check if the angle is equal to the given constant."""
+  a, b, c, d, m, n = points
+  a, b, c, d = bring_together(a, b, c, d)
+  ba = b - a
+  dc = d - c
+
+  a3 = np.arctan2(ba.y, ba.x)
+  a4 = np.arctan2(dc.y, dc.x)
+  y = a3 - a4
+
+  return close_enough(m / n % 1, y / np.pi % 1)
+
+
+def check_eqangle(points: list[Point]) -> bool:
+  """Check if 8 points make 2 equal angles."""
+  a, b, c, d, e, f, g, h = points
+
+  ab = Line(a, b)
+  cd = Line(c, d)
+  ef = Line(e, f)
+  gh = Line(g, h)
+
+  if ab.is_parallel(cd):
+    return ef.is_parallel(gh)
+  if ef.is_parallel(gh):
+    return ab.is_parallel(cd)
+
+  a, b, c, d = bring_together(a, b, c, d)
+  e, f, g, h = bring_together(e, f, g, h)
+
+  ba = b - a
+  dc = d - c
+  fe = f - e
+  hg = h - g
+
+  sameclock = (ba.x * dc.y - ba.y * dc.x) * (fe.x * hg.y - fe.y * hg.x) > 0
+  if not sameclock:
+    ba = ba * -1.0
+
+  a1 = np.arctan2(fe.y, fe.x)
+  a2 = np.arctan2(hg.y, hg.x)
+  x = a1 - a2
+
+  a3 = np.arctan2(ba.y, ba.x)
+  a4 = np.arctan2(dc.y, dc.x)
+  y = a3 - a4
+
+  xy = (x - y) % (2 * np.pi)
+  return close_enough(xy, 0, tol=1e-11) or close_enough(
+      xy, 2 * np.pi, 
tol=1e-11 + ) + + +def check_eqratio(points: list[Point]) -> bool: + a, b, c, d, e, f, g, h = points + ab = a.distance(b) + cd = c.distance(d) + ef = e.distance(f) + gh = g.distance(h) + return close_enough(ab * gh, cd * ef) + + +def check_cong(points: list[Point]) -> bool: + a, b, c, d = points + return close_enough(a.distance(b), c.distance(d)) + + +def check_midp(points: list[Point]) -> bool: + a, b, c = points + return check_coll(points) and close_enough(a.distance(b), a.distance(c)) + + +def check_simtri(points: list[Point]) -> bool: + """Check if 6 points make a pair of similar triangles.""" + a, b, c, x, y, z = points + ab = a.distance(b) + bc = b.distance(c) + ca = c.distance(a) + xy = x.distance(y) + yz = y.distance(z) + zx = z.distance(x) + tol = 1e-9 + return close_enough(ab * yz, bc * xy, tol) and close_enough( + bc * zx, ca * yz, tol + ) + + +def check_contri(points: list[Point]) -> bool: + a, b, c, x, y, z = points + ab = a.distance(b) + bc = b.distance(c) + ca = c.distance(a) + xy = x.distance(y) + yz = y.distance(z) + zx = z.distance(x) + tol = 1e-9 + return ( + close_enough(ab, xy, tol) + and close_enough(bc, yz, tol) + and close_enough(ca, zx, tol) + ) + + +def check_ratio(points: list[Point]) -> bool: + a, b, c, d, m, n = points + ab = a.distance(b) + cd = c.distance(d) + return close_enough(ab * n, cd * m) + + +def draw_angle( + ax: matplotlib.axes.Axes, + head: Point, + p1: Point, + p2: Point, + color: Any = 'red', + alpha: float = 0.5, + frac: float = 1.0, +) -> None: + """Draw an angle on plt ax.""" + d1 = p1 - head + d2 = p2 - head + + a1 = np.arctan2(float(d1.y), float(d1.x)) + a2 = np.arctan2(float(d2.y), float(d2.x)) + a1, a2 = a1 * 180 / np.pi, a2 * 180 / np.pi + a1, a2 = a1 % 360, a2 % 360 + + if a1 > a2: + a1, a2 = a2, a1 + + if a2 - a1 > 180: + a1, a2 = a2, a1 + + b1, b2 = a1, a2 + if b1 > b2: + b2 += 360 + d = b2 - b1 + # if d >= 90: + # return + + scale = min(2.0, 90 / d) + scale = max(scale, 0.4) + fov = matplotlib.patches.Wedge( + 
(float(head.x), float(head.y)), + unif(0.075, 0.125) * scale * frac, + a1, + a2, + color=color, + alpha=alpha, + ) + ax.add_artist(fov) + + +def naming_position( + ax: matplotlib.axes.Axes, p: Point, lines: list[Line], circles: list[Circle] +) -> tuple[float, float]: + """Figure out a good naming position on the drawing.""" + _ = ax + r = 0.08 + c = Circle(center=p, radius=r) + avoid = [] + for p1, p2 in lines: + try: + avoid.extend(circle_segment_intersect(c, p1, p2)) + except InvalidQuadSolveError: + continue + for x in circles: + try: + avoid.extend(circle_circle_intersection(c, x)) + except InvalidQuadSolveError: + continue + + if not avoid: + return [p.x + 0.01, p.y + 0.01] + + angs = sorted([ang_of(p, a) for a in avoid]) + angs += [angs[0] + 2 * np.pi] + angs = [(angs[i + 1] - a, a) for i, a in enumerate(angs[:-1])] + + d, a = max(angs) + ang = a + d / 2 + + name_pos = p + Point(np.cos(ang), np.sin(ang)) * r + + x, y = (name_pos.x - r / 1.5, name_pos.y - r / 1.5) + return x, y + + +def draw_point( + ax: matplotlib.axes.Axes, + p: Point, + name: str, + lines: list[Line], + circles: list[Circle], + color: Any = 'white', + size: float = 15, +) -> None: + """draw a point.""" + ax.scatter(p.x, p.y, color=color, s=size) + + if color == 'white': + color = 'lightgreen' + else: + color = 'grey' + + name = name.upper() + if len(name) > 1: + name = name[0] + '_' + name[1:] + + ax.annotate( + name, naming_position(ax, p, lines, circles), color=color, fontsize=15 + ) + + +def _draw_line( + ax: matplotlib.axes.Axes, + p1: Point, + p2: Point, + color: Any = 'white', + lw: float = 1.2, + alpha: float = 0.8, +) -> None: + """Draw a line in matplotlib.""" + ls = '-' + if color == '--': + color = 'black' + ls = '--' + + lx, ly = (p1.x, p2.x), (p1.y, p2.y) + ax.plot(lx, ly, color=color, lw=lw, alpha=alpha, ls=ls) + + +def draw_line( + ax: matplotlib.axes.Axes, line: Line, color: Any = 'white' +) -> tuple[Point, Point]: + """Draw a line.""" + points = line.neighbors(gm.Point) + 
if len(points) <= 1: + return + + points = [p.num for p in points] + p1, p2 = points[:2] + + pmin, pmax = (p1, 0.0), (p2, (p2 - p1).dot(p2 - p1)) + + for p in points[2:]: + v = (p - p1).dot(p2 - p1) + if v < pmin[1]: + pmin = p, v + if v > pmax[1]: + pmax = p, v + + p1, p2 = pmin[0], pmax[0] + _draw_line(ax, p1, p2, color=color) + return p1, p2 + + +def _draw_circle( + ax: matplotlib.axes.Axes, c: Circle, color: Any = 'cyan', lw: float = 1.2 +) -> None: + ls = '-' + if color == '--': + color = 'black' + ls = '--' + + ax.add_patch( + plt.Circle( + (c.center.x, c.center.y), + c.radius, + color=color, + alpha=0.8, + fill=False, + lw=lw, + ls=ls, + ) + ) + + +def draw_circle( + ax: matplotlib.axes.Axes, circle: Circle, color: Any = 'cyan' +) -> Circle: + """Draw a circle.""" + if circle.num is not None: + circle = circle.num + else: + points = circle.neighbors(gm.Point) + if len(points) <= 2: + return + points = [p.num for p in points] + p1, p2, p3 = points[:3] + circle = Circle(p1=p1, p2=p2, p3=p3) + + _draw_circle(ax, circle, color) + return circle + + +def mark_segment( + ax: matplotlib.axes.Axes, p1: Point, p2: Point, color: Any, alpha: float +) -> None: + _ = alpha + x, y = (p1.x + p2.x) / 2, (p1.y + p2.y) / 2 + ax.scatter(x, y, color=color, alpha=1.0, marker='o', s=50) + + +def highlight_angle( + ax: matplotlib.axes.Axes, + a: Point, + b: Point, + c: Point, + d: Point, + color: Any, + alpha: float, +) -> None: + """Highlight an angle between ab and cd with (color, alpha).""" + try: + a, b, c, d = bring_together(a, b, c, d) + except: # pylint: disable=bare-except + return + draw_angle(ax, a, b, d, color=color, alpha=alpha, frac=1.0) + + +def highlight( + ax: matplotlib.axes.Axes, + name: str, + args: list[gm.Point], + lcolor: Any, + color1: Any, + color2: Any, +) -> None: + """Draw highlights.""" + args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args)) + + if name == 'cyclic': + a, b, c, d = args + _draw_circle(ax, Circle(p1=a, p2=b, p3=c), 
color=color1, lw=2.0) + if name == 'coll': + a, b, c = args + a, b = max(a, b, c), min(a, b, c) + _draw_line(ax, a, b, color=color1, lw=2.0) + if name == 'para': + a, b, c, d = args + _draw_line(ax, a, b, color=color1, lw=2.0) + _draw_line(ax, c, d, color=color2, lw=2.0) + if name == 'eqangle': + a, b, c, d, e, f, g, h = args + + x = line_line_intersection(Line(a, b), Line(c, d)) + if b.distance(x) > a.distance(x): + a, b = b, a + if d.distance(x) > c.distance(x): + c, d = d, c + a, b, d = x, a, c + + y = line_line_intersection(Line(e, f), Line(g, h)) + if f.distance(y) > e.distance(y): + e, f = f, e + if h.distance(y) > g.distance(y): + g, h = h, g + e, f, h = y, e, g + + _draw_line(ax, a, b, color=lcolor, lw=2.0) + _draw_line(ax, a, d, color=lcolor, lw=2.0) + _draw_line(ax, e, f, color=lcolor, lw=2.0) + _draw_line(ax, e, h, color=lcolor, lw=2.0) + if color1 == '--': + color1 = 'red' + draw_angle(ax, a, b, d, color=color1, alpha=0.5) + if color2 == '--': + color2 = 'red' + draw_angle(ax, e, f, h, color=color2, alpha=0.5) + if name == 'perp': + a, b, c, d = args + _draw_line(ax, a, b, color=color1, lw=2.0) + _draw_line(ax, c, d, color=color1, lw=2.0) + if name == 'ratio': + a, b, c, d, m, n = args + _draw_line(ax, a, b, color=color1, lw=2.0) + _draw_line(ax, c, d, color=color2, lw=2.0) + if name == 'cong': + a, b, c, d = args + _draw_line(ax, a, b, color=color1, lw=2.0) + _draw_line(ax, c, d, color=color2, lw=2.0) + if name == 'midp': + m, a, b = args + _draw_line(ax, a, m, color=color1, lw=2.0, alpha=0.5) + _draw_line(ax, b, m, color=color2, lw=2.0, alpha=0.5) + if name == 'eqratio': + a, b, c, d, m, n, p, q = args + _draw_line(ax, a, b, color=color1, lw=2.0, alpha=0.5) + _draw_line(ax, c, d, color=color2, lw=2.0, alpha=0.5) + _draw_line(ax, m, n, color=color1, lw=2.0, alpha=0.5) + _draw_line(ax, p, q, color=color2, lw=2.0, alpha=0.5) + + +HCOLORS = None + + +def _draw( + ax: matplotlib.axes.Axes, + points: list[gm.Point], + lines: list[gm.Line], + circles: 
list[gm.Circle],
+    goal: Any,
+    equals: list[tuple[Any, Any]],
+    highlights: list[tuple[str, list[gm.Point]]],
+):
+  """Draw everything."""
+  colors = ['red', 'green', 'blue', 'orange', 'magenta', 'purple']
+  pcolor = 'black'
+  lcolor = 'black'
+  ccolor = 'grey'
+  if get_theme() == 'dark':
+    pcolor, lcolor, ccolor = 'white', 'white', 'cyan'
+  elif get_theme() == 'light':
+    pcolor, lcolor, ccolor = 'black', 'black', 'blue'
+  elif get_theme() == 'grey':
+    pcolor, lcolor, ccolor = 'black', 'black', 'grey'
+    colors = ['grey']
+
+  line_boundaries = []
+  for l in lines:
+    p1, p2 = draw_line(ax, l, color=lcolor)
+    line_boundaries.append((p1, p2))
+  circles = [draw_circle(ax, c, color=ccolor) for c in circles]
+
+  for p in points:
+    draw_point(ax, p.num, p.name, line_boundaries, circles, color=pcolor)
+
+  if equals:
+    for i, segs in enumerate(equals['segments']):
+      color = colors[i % len(colors)]
+      for a, b in segs:
+        mark_segment(ax, a, b, color, 0.5)
+
+    for i, angs in enumerate(equals['angles']):
+      color = colors[i % len(colors)]
+      for a, b, c, d in angs:
+        highlight_angle(ax, a, b, c, d, color, 0.5)
+
+  if highlights:
+    global HCOLORS
+    if HCOLORS is None:
+      HCOLORS = [k for k in mcolors.TABLEAU_COLORS.keys() if 'red' not in k]
+
+    for i, (name, args) in enumerate(highlights):
+      color_i = HCOLORS[i % len(HCOLORS)]
+      highlight(ax, name, args, 'black', color_i, color_i)
+
+  if goal:
+    name, args = goal
+    lcolor = color1 = color2 = 'red'
+    highlight(ax, name, args, lcolor, color1, color2)
+
+
+THEME = 'dark'
+
+
+def set_theme(theme) -> None:
+  global THEME
+  THEME = theme
+
+
+def get_theme() -> str:
+  return THEME
+
+
+def draw(
+    points: list[gm.Point],
+    lines: list[gm.Line],
+    circles: list[gm.Circle],
+    segments: list[gm.Segment],
+    goal: Any = None,
+    highlights: list[tuple[str, list[gm.Point]]] = None,
+    equals: list[tuple[Any, Any]] = None,
+    block: bool = True,
+    save_to: str = None,
+    theme: str = 'dark',
+) -> None:
+  """Draw everything on the same canvas."""
+  
plt.close() + imsize = 512 / 100 + fig, ax = plt.subplots(figsize=(imsize, imsize), dpi=100) + + set_theme(theme) + + if get_theme() == 'dark': + ax.set_facecolor((0.0, 0.0, 0.0)) + else: + ax.set_facecolor((1.0, 1.0, 1.0)) + + _draw(ax, points, lines, circles, goal, equals, highlights) + + plt.axis('equal') + fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0) + if points: + xmin = min([p.num.x for p in points]) + xmax = max([p.num.x for p in points]) + ymin = min([p.num.y for p in points]) + ymax = max([p.num.y for p in points]) + plt.margins((xmax - xmin) * 0.1, (ymax - ymin) * 0.1) + + plt.show(block=block) + + +def close_enough(a: float, b: float, tol: float = 1e-12) -> bool: + return abs(a - b) < tol + + +def assert_close_enough(a: float, b: float, tol: float = 1e-12) -> None: + assert close_enough(a, b, tol), f'|{a}-{b}| = {abs(a-b)} >= {tol}' + + +def ang_of(tail: Point, head: Point) -> float: + vector = head - tail + arctan = np.arctan2(vector.y, vector.x) % (2 * np.pi) + return arctan + + +def ang_between(tail: Point, head1: Point, head2: Point) -> float: + ang1 = ang_of(tail, head1) + ang2 = ang_of(tail, head2) + diff = ang1 - ang2 + # return diff % (2*np.pi) + if diff > np.pi: + return diff - 2 * np.pi + if diff < -np.pi: + return 2 * np.pi + diff + return diff + + +def head_from(tail: Point, ang: float, length: float = 1) -> Point: + vector = Point(np.cos(ang) * length, np.sin(ang) * length) + return tail + vector + + +def random_points(n: int = 3) -> list[Point]: + return [Point(unif(-1, 1), unif(-1, 1)) for _ in range(n)] + + +def random_rfss(*points: list[Point]) -> list[Point]: + """Random rotate-flip-scale-shift a point cloud.""" + # center point cloud. 
+ average = sum(points, Point(0.0, 0.0)) * (1.0 / len(points)) + points = [p - average for p in points] + + # rotate + ang = unif(0.0, 2 * np.pi) + sin, cos = np.sin(ang), np.cos(ang) + # scale and shift + scale = unif(0.5, 2.0) + shift = Point(unif(-1, 1), unif(-1, 1)) + points = [p.rotate(sin, cos) * scale + shift for p in points] + + # randomly flip + if np.random.rand() < 0.5: + points = [p.flip() for p in points] + + return points + + +def reduce( + objs: list[Union[Point, Line, Circle, HalfLine, HoleCircle]], + existing_points: list[Point], +) -> list[Point]: + """Reduce intersecting objects into one point of intersections.""" + if all(isinstance(o, Point) for o in objs): + return objs + + elif len(objs) == 1: + return objs[0].sample_within(existing_points) + + elif len(objs) == 2: + a, b = objs + result = a.intersect(b) + if isinstance(result, Point): + return [result] + a, b = result + a_close = any([a.close(x) for x in existing_points]) + if a_close: + return [b] + b_close = any([b.close(x) for x in existing_points]) + if b_close: + return [a] + return [np.random.choice([a, b])] + + else: + raise ValueError(f'Cannot reduce {objs}') + + +def sketch( + name: str, args: list[Union[Point, gm.Point]] +) -> list[Union[Point, Line, Circle, HalfLine, HoleCircle]]: + fun = globals()['sketch_' + name] + args = [p.num if isinstance(p, gm.Point) else p for p in args] + out = fun(args) + + # out can be one or multiple {Point/Line/HalfLine} + if isinstance(out, (tuple, list)): + return list(out) + return [out] + + +def sketch_on_opline(args: tuple[gm.Point, ...]) -> HalfLine: + a, b = args + return HalfLine(a, a + a - b) + + +def sketch_on_hline(args: tuple[gm.Point, ...]) -> HalfLine: + a, b = args + return HalfLine(a, b) + + +def sketch_ieq_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(1.0, 0.0) + + c, _ = Circle(a, p1=b).intersect(Circle(b, p1=a)) + return a, b, c + + +def sketch_incenter2(args: tuple[gm.Point, ...]) -> 
tuple[Point, ...]: + a, b, c = args + l1 = sketch_bisect([b, a, c]) + l2 = sketch_bisect([a, b, c]) + i = line_line_intersection(l1, l2) + x = i.foot(Line(b, c)) + y = i.foot(Line(c, a)) + z = i.foot(Line(a, b)) + return x, y, z, i + + +def sketch_excenter2(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a, b, c = args + l1 = sketch_bisect([b, a, c]) + l2 = sketch_exbisect([a, b, c]) + i = line_line_intersection(l1, l2) + x = i.foot(Line(b, c)) + y = i.foot(Line(c, a)) + z = i.foot(Line(a, b)) + return x, y, z, i + + +def sketch_centroid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a, b, c = args + x = (b + c) * 0.5 + y = (c + a) * 0.5 + z = (a + b) * 0.5 + i = line_line_intersection(Line(a, x), Line(b, y)) + return x, y, z, i + + +def sketch_ninepoints(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a, b, c = args + x = (b + c) * 0.5 + y = (c + a) * 0.5 + z = (a + b) * 0.5 + c = Circle(p1=x, p2=y, p3=z) + return x, y, z, c.center + + +def sketch_2l1c(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + """Sketch a circle touching two lines and another circle.""" + a, b, c, p = args + bc, ac = Line(b, c), Line(a, c) + circle = Circle(p, p1=a) + + d, d_ = line_circle_intersection(p.perpendicular_line(bc), circle) + if bc.diff_side(d_, a): + d = d_ + + e, e_ = line_circle_intersection(p.perpendicular_line(ac), circle) + if ac.diff_side(e_, b): + e = e_ + + df = d.perpendicular_line(Line(p, d)) + ef = e.perpendicular_line(Line(p, e)) + f = line_line_intersection(df, ef) + + g, g_ = line_circle_intersection(Line(c, f), circle) + if bc.same_side(g_, a): + g = g_ + + b_ = c + (b - c) / b.distance(c) + a_ = c + (a - c) / a.distance(c) + m = (a_ + b_) * 0.5 + x = line_line_intersection(Line(c, m), Line(p, g)) + return x.foot(ac), x.foot(bc), g, x + + +def sketch_3peq(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a, b, c = args + ab, bc, ca = Line(a, b), Line(b, c), Line(c, a) + + z = b + (c - b) * np.random.uniform(-0.5, 1.5) + + z_ = z * 2 - c + l = 
z_.parallel_line(ca) + x = line_line_intersection(l, ab) + y = z * 2 - x + return x, y, z + + +def try_to_sketch_intersect( + name1: str, + args1: list[Union[gm.Point, Point]], + name2: str, + args2: list[Union[gm.Point, Point]], + existing_points: list[Point], +) -> Optional[Point]: + """Try to sketch an intersection between two objects.""" + obj1 = sketch(name1, args1)[0] + obj2 = sketch(name2, args2)[0] + + if isinstance(obj1, Line) and isinstance(obj2, Line): + fn = line_line_intersection + elif isinstance(obj1, Circle) and isinstance(obj2, Circle): + fn = circle_circle_intersection + else: + fn = line_circle_intersection + if isinstance(obj2, Line) and isinstance(obj1, Circle): + obj1, obj2 = obj2, obj1 + + try: + x = fn(obj1, obj2) + except: # pylint: disable=bare-except + return None + + if isinstance(x, Point): + return x + + x1, x2 = x + + close1 = check_too_close([x1], existing_points) + far1 = check_too_far([x1], existing_points) + if not close1 and not far1: + return x1 + close2 = check_too_close([x2], existing_points) + far2 = check_too_far([x2], existing_points) + if not close2 and not far2: + return x2 + + return None + + +def sketch_acircle(args: tuple[gm.Point, ...]) -> Circle: + a, b, c, d, f = args + de = sketch_aline([c, a, b, f, d]) + fe = sketch_aline([a, c, b, d, f]) + e = line_line_intersection(de, fe) + return Circle(p1=d, p2=e, p3=f) + + +def sketch_aline(args: tuple[gm.Point, ...]) -> HalfLine: + """Sketch the construction aline.""" + A, B, C, D, E = args + ab = A - B + cb = C - B + de = D - E + + dab = A.distance(B) + ang_ab = np.arctan2(ab.y / dab, ab.x / dab) + + dcb = C.distance(B) + ang_bc = np.arctan2(cb.y / dcb, cb.x / dcb) + + dde = D.distance(E) + ang_de = np.arctan2(de.y / dde, de.x / dde) + + ang_ex = ang_de + ang_bc - ang_ab + X = E + Point(np.cos(ang_ex), np.sin(ang_ex)) + return HalfLine(E, X) + + +def sketch_amirror(args: tuple[gm.Point, ...]) -> HalfLine: + """Sketch the angle mirror.""" + A, B, C = args # pylint: 
disable=invalid-name + ab = A - B + cb = C - B + + dab = A.distance(B) + ang_ab = np.arctan2(ab.y / dab, ab.x / dab) + dcb = C.distance(B) + ang_bc = np.arctan2(cb.y / dcb, cb.x / dcb) + + ang_bx = 2 * ang_bc - ang_ab + X = B + Point(np.cos(ang_bx), np.sin(ang_bx)) # pylint: disable=invalid-name + return HalfLine(B, X) + + +def sketch_bisect(args: tuple[gm.Point, ...]) -> Line: + a, b, c = args + ab = a.distance(b) + bc = b.distance(c) + x = b + (c - b) * (ab / bc) + m = (a + x) * 0.5 + return Line(b, m) + + +def sketch_exbisect(args: tuple[gm.Point, ...]) -> Line: + a, b, c = args + return sketch_bisect(args).perpendicular_line(b) + + +def sketch_bline(args: tuple[gm.Point, ...]) -> Line: + a, b = args + m = (a + b) * 0.5 + return m.perpendicular_line(Line(a, b)) + + +def sketch_dia(args: tuple[gm.Point, ...]) -> Circle: + a, b = args + return Circle((a + b) * 0.5, p1=a) + + +def sketch_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + a, o, b = args + dia = sketch_dia([a, o]) + return circle_circle_intersection(Circle(o, p1=b), dia) + + +def sketch_circle(args: tuple[gm.Point, ...]) -> Circle: + a, b, c = args + return Circle(center=a, radius=b.distance(c)) + + +def sketch_cc_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + """Sketch tangents to two circles.""" + o, a, w, b = args + ra, rb = o.distance(a), w.distance(b) + + ow = Line(o, w) + if close_enough(ra, rb): + oo = ow.perpendicular_line(o) + oa = Circle(o, ra) + x, z = line_circle_intersection(oo, oa) + y = x + w - o + t = z + w - o + return x, y, z, t + + swap = rb > ra + if swap: + o, a, w, b = w, b, o, a + ra, rb = rb, ra + + oa = Circle(o, ra) + q = o + (w - o) * ra / (ra - rb) + + x, z = circle_circle_intersection(sketch_dia([o, q]), oa) + y = w.foot(Line(x, q)) + t = w.foot(Line(z, q)) + + if swap: + x, y, z, t = y, x, t, z + + return x, y, z, t + + +def sketch_hcircle(args: tuple[gm.Point, ...]) -> HoleCircle: + a, b = args + return HoleCircle(center=a, radius=a.distance(b), 
hole=b) + + +def sketch_e5128(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + a, b, c, d = args + ad = Line(a, d) + + g = (a + b) * 0.5 + de = Line(d, g) + + e, f = line_circle_intersection(de, Circle(c, p1=b)) + + if e.distance(d) < f.distance(d): + e = f + return e, g + + +def sketch_eq_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + """Sketch quadrangle with two equal opposite sides.""" + a = Point(0.0, 0.0) + b = Point(1.0, 0.0) + + length = np.random.uniform(0.5, 2.0) + ang = np.random.uniform(np.pi / 3, np.pi * 2 / 3) + d = head_from(a, ang, length) + + ang = ang_of(b, d) + ang = np.random.uniform(ang / 10, ang / 9) + c = head_from(b, ang, length) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_eq_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(1.0, 0.0) + l = unif(0.5, 2.0) + + height = unif(0.5, 2.0) + c = Point(0.5 + l / 2.0, height) + d = Point(0.5 - l / 2.0, height) + + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_eqangle2(args: tuple[gm.Point, ...]) -> Point: + """Sketch the def eqangle2.""" + a, b, c = args + + d = c * 2 - b + + ba = b.distance(a) + bc = b.distance(c) + l = ba * ba / bc + + if unif(0.0, 1.0) < 0.5: + be = min(l, bc) + be = unif(be * 0.1, be * 0.9) + else: + be = max(l, bc) + be = unif(be * 1.1, be * 1.5) + + e = b + (c - b) * (be / bc) + y = b + (a - b) * (be / l) + return line_line_intersection(Line(c, y), Line(a, e)) + + +def sketch_eqangle3(args: tuple[gm.Point, ...]) -> Circle: + a, b, d, e, f = args + de = d.distance(e) + ef = e.distance(f) + ab = b.distance(a) + ang_ax = ang_of(a, b) + ang_between(e, d, f) + x = head_from(a, ang_ax, length=de / ef * ab) + return Circle(p1=a, p2=b, p3=x) + + +def sketch_eqdia_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + """Sketch quadrangle with two equal diagonals.""" + m = unif(0.3, 0.7) + n = unif(0.3, 0.7) + a = Point(-m, 0.0) + c = Point(1 - m, 0.0) + b = 
Point(0.0, -n) + d = Point(0.0, 1 - n) + + ang = unif(-0.25 * np.pi, 0.25 * np.pi) + sin, cos = np.sin(ang), np.cos(ang) + b = b.rotate(sin, cos) + d = d.rotate(sin, cos) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_free(args: tuple[gm.Point, ...]) -> Point: + return random_points(1)[0] + + +def sketch_isos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + base = unif(0.5, 1.5) + height = unif(0.5, 1.5) + + b = Point(-base / 2, 0.0) + c = Point(base / 2, 0.0) + a = Point(0.0, height) + a, b, c = random_rfss(a, b, c) + return a, b, c + + +def sketch_line(args: tuple[gm.Point, ...]) -> Line: + a, b = args + return Line(a, b) + + +def sketch_cyclic(args: tuple[gm.Point, ...]) -> Circle: + a, b, c = args + return Circle(p1=a, p2=b, p3=c) + + +def sketch_hline(args: tuple[gm.Point, ...]) -> HalfLine: + a, b = args + return HalfLine(a, b) + + +def sketch_midp(args: tuple[gm.Point, ...]) -> Point: + a, b = args + return (a + b) * 0.5 + + +def sketch_pentagon(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + points = [Point(1.0, 0.0)] + ang = 0.0 + + for i in range(4): + ang += (2 * np.pi - ang) / (5 - i) * unif(0.5, 1.5) + point = Point(np.cos(ang), np.sin(ang)) + points.append(point) + + a, b, c, d, e = points # pylint: disable=unbalanced-tuple-unpacking + a, b, c, d, e = random_rfss(a, b, c, d, e) + return a, b, c, d, e + + +def sketch_pline(args: tuple[gm.Point, ...]) -> Line: + a, b, c = args + return a.parallel_line(Line(b, c)) + + +def sketch_pmirror(args: tuple[gm.Point, ...]) -> Point: + a, b = args + return b * 2 - a + + +def sketch_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + """Sketch a random quadrangle.""" + m = unif(0.3, 0.7) + n = unif(0.3, 0.7) + + a = Point(-m, 0.0) + c = Point(1 - m, 0.0) + b = Point(0.0, -unif(0.25, 0.75)) + d = Point(0.0, unif(0.25, 0.75)) + + ang = unif(-0.25 * np.pi, 0.25 * np.pi) + sin, cos = np.sin(ang), np.cos(ang) + b = b.rotate(sin, cos) + d = d.rotate(sin, cos) + a, b, c, d = 
random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_r_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 1.0) + d = Point(0.0, 0.0) + b = Point(unif(0.5, 1.5), 1.0) + c = Point(unif(0.5, 1.5), 0.0) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_r_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(0.0, unif(0.5, 2.0)) + c = Point(unif(0.5, 2.0), 0.0) + a, b, c = random_rfss(a, b, c) + return a, b, c + + +def sketch_rectangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(0.0, 1.0) + l = unif(0.5, 2.0) + c = Point(l, 1.0) + d = Point(l, 0.0) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_reflect(args: tuple[gm.Point, ...]) -> Point: + a, b, c = args + m = a.foot(Line(b, c)) + return m * 2 - a + + +def sketch_risos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(0.0, 1.0) + c = Point(1.0, 0.0) + a, b, c = random_rfss(a, b, c) + return a, b, c + + +def sketch_rotaten90(args: tuple[gm.Point, ...]) -> Point: + a, b = args + ang = -np.pi / 2 + return a + (b - a).rotate(np.sin(ang), np.cos(ang)) + + +def sketch_rotatep90(args: tuple[gm.Point, ...]) -> Point: + a, b = args + ang = np.pi / 2 + return a + (b - a).rotate(np.sin(ang), np.cos(ang)) + + +def sketch_s_angle(args: tuple[gm.Point, ...]) -> HalfLine: + a, b, y = args + ang = y / 180 * np.pi + x = b + (a - b).rotatea(ang) + return HalfLine(b, x) + + +def sketch_segment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + a, b = random_points(2) + return a, b + + +def sketch_shift(args: tuple[gm.Point, ...]) -> Point: + a, b, c = args + return c + (b - a) + + +def sketch_square(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + a, b = args + c = b + (a - b).rotatea(-np.pi / 2) + d = a + (b - a).rotatea(np.pi / 2) + return c, d + + +def sketch_isquare(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = 
Point(0.0, 0.0) + b = Point(1.0, 0.0) + c = Point(1.0, 1.0) + d = Point(0.0, 1.0) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_tline(args: tuple[gm.Point, ...]) -> Line: + a, b, c = args + return a.perpendicular_line(Line(b, c)) + + +def sketch_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + d = Point(0.0, 0.0) + c = Point(1.0, 0.0) + + base = unif(0.5, 2.0) + height = unif(0.5, 2.0) + a = Point(unif(0.2, 0.5), height) + b = Point(a.x + base, height) + a, b, c, d = random_rfss(a, b, c, d) + return a, b, c, d + + +def sketch_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + a = Point(0.0, 0.0) + b = Point(1.0, 0.0) + ac = unif(0.5, 2.0) + ang = unif(0.2, 0.8) * np.pi + c = head_from(a, ang, ac) + return a, b, c + + +def sketch_triangle12(args: tuple[gm.Point, ...]) -> tuple[Point, ...]: + b = Point(0.0, 0.0) + c = Point(unif(1.5, 2.5), 0.0) + a, _ = circle_circle_intersection(Circle(b, 1.0), Circle(c, 2.0)) + a, b, c = random_rfss(a, b, c) + return a, b, c + + +def sketch_trisect(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + """Sketch two trisectors of an angle.""" + a, b, c = args + ang1 = ang_of(b, a) + ang2 = ang_of(b, c) + + swap = 0 + if ang1 > ang2: + ang1, ang2 = ang2, ang1 + swap += 1 + + if ang2 - ang1 > np.pi: + ang1, ang2 = ang2, ang1 + 2 * np.pi + swap += 1 + + angx = ang1 + (ang2 - ang1) / 3 + angy = ang2 - (ang2 - ang1) / 3 + + x = b + Point(np.cos(angx), np.sin(angx)) + y = b + Point(np.cos(angy), np.sin(angy)) + + ac = Line(a, c) + x = line_line_intersection(Line(b, x), ac) + y = line_line_intersection(Line(b, y), ac) + + if swap == 1: + return y, x + return x, y + + +def sketch_trisegment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]: + a, b = args + x, y = a + (b - a) * (1.0 / 3), a + (b - a) * (2.0 / 3) + return x, y diff --git a/numericals_test.py b/numericals_test.py new file mode 100644 index 0000000..96b0894 --- /dev/null +++ b/numericals_test.py @@ -0,0 +1,313 @@ +# Copyright 
2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit testing for the geometry numericals code.""" + +import unittest + +from absl.testing import absltest +import numericals as nm + +np = nm.np + +unif = nm.unif +Point = nm.Point +Line = nm.Line +Circle = nm.Circle +HalfLine = nm.HalfLine + +line_circle_intersection = nm.line_circle_intersection +line_line_intersection = nm.line_line_intersection + +check_coll = nm.check_coll +check_eqangle = nm.check_eqangle + +random_points = nm.random_points +ang_between = nm.ang_between +head_from = nm.head_from + + +class NumericalTest(unittest.TestCase): + + def test_sketch_ieq_triangle(self): + a, b, c = nm.sketch_ieq_triangle([]) + self.assertAlmostEqual(a.distance(b), b.distance(c)) + self.assertAlmostEqual(c.distance(a), b.distance(c)) + + def test_sketch_2l1c(self): + p = nm.Point(0.0, 0.0) + pi = np.pi + anga = unif(-0.4 * pi, 0.4 * pi) + a = Point(np.cos(anga), np.sin(anga)) + angb = unif(0.6 * pi, 1.4 * pi) + b = Point(np.cos(angb), np.sin(angb)) + + angc = unif(anga + 0.05 * pi, angb - 0.05 * pi) + c = Point(np.cos(angc), np.sin(angc)) * unif(0.2, 0.8) + + x, y, z, i = nm.sketch_2l1c([a, b, c, p]) + self.assertTrue(check_coll([x, c, a])) + self.assertTrue(check_coll([y, c, b])) + self.assertAlmostEqual(z.distance(p), 1.0) + self.assertTrue(check_coll([p, i, z])) + self.assertTrue(Line(i, 
x).is_perp(Line(c, a))) + self.assertTrue(Line(i, y).is_perp(Line(c, b))) + self.assertAlmostEqual(i.distance(x), i.distance(y)) + self.assertAlmostEqual(i.distance(x), i.distance(z)) + + def test_sketch_3peq(self): + a, b, c = random_points(3) + x, y, z = nm.sketch_3peq([a, b, c]) + + self.assertTrue(check_coll([a, b, x])) + self.assertTrue(check_coll([a, c, y])) + self.assertTrue(check_coll([b, c, z])) + self.assertTrue(check_coll([x, y, z])) + self.assertAlmostEqual(z.distance(x), z.distance(y)) + + def test_sketch_aline(self): + a, b, c, d, e = random_points(5) + ex = nm.sketch_aline([a, b, c, d, e]) + self.assertIsInstance(ex, HalfLine) + self.assertEqual(ex.tail, e) + x = ex.head + self.assertAlmostEqual(ang_between(b, a, c), ang_between(e, d, x)) + + def test_sketch_amirror(self): + a, b, c = random_points(3) + bx = nm.sketch_amirror([a, b, c]) + self.assertIsInstance(bx, HalfLine) + assert bx.tail == b + x = bx.head + + ang1 = ang_between(b, a, c) + ang2 = ang_between(b, c, x) + self.assertAlmostEqual(ang1, ang2) + + def test_sketch_bisect(self): + a, b, c = random_points(3) + line = nm.sketch_bisect([a, b, c]) + self.assertAlmostEqual(b.distance(line), 0.0) + + l = a.perpendicular_line(line) + x = line_line_intersection(l, Line(b, c)) + self.assertAlmostEqual(a.distance(line), x.distance(line)) + + d, _ = line_circle_intersection(line, Circle(b, radius=1)) + ang1 = ang_between(b, a, d) + ang2 = ang_between(b, d, c) + self.assertAlmostEqual(ang1, ang2) + + def test_sketch_bline(self): + a, b = random_points(2) + l = nm.sketch_bline([a, b]) + self.assertTrue(Line(a, b).is_perp(l)) + self.assertAlmostEqual(a.distance(l), b.distance(l)) + + def test_sketch_cc_tangent(self): + o = Point(0.0, 0.0) + w = Point(1.0, 0.0) + + ra = unif(0.0, 0.6) + rb = unif(0.4, 1.0) + + a = unif(0.0, np.pi) + b = unif(0.0, np.pi) + + a = o + ra * Point(np.cos(a), np.sin(a)) + b = w + rb * Point(np.sin(b), np.cos(b)) + + x, y, z, t = nm.sketch_cc_tangent([o, a, w, b]) + xy = 
Line(x, y) + zt = Line(z, t) + self.assertAlmostEqual(o.distance(xy), o.distance(a)) + self.assertAlmostEqual(o.distance(zt), o.distance(a)) + self.assertAlmostEqual(w.distance(xy), w.distance(b)) + self.assertAlmostEqual(w.distance(zt), w.distance(b)) + + def test_sketch_circle(self): + a, b, c = random_points(3) + circle = nm.sketch_circle([a, b, c]) + self.assertAlmostEqual(circle.center.distance(a), 0.0) + self.assertAlmostEqual(circle.radius, b.distance(c)) + + def test_sketch_e5128(self): + b = Point(0.0, 0.0) + c = Point(0.0, 1.0) + ang = unif(-np.pi / 2, 3 * np.pi / 2) + d = head_from(c, ang, 1.0) + a = Point(unif(0.5, 2.0), 0.0) + + e, g = nm.sketch_e5128([a, b, c, d]) + ang1 = ang_between(a, b, d) + ang2 = ang_between(e, a, g) + self.assertAlmostEqual(ang1, ang2) + + def test_sketch_eq_quadrangle(self): + a, b, c, d = nm.sketch_eq_quadrangle([]) + self.assertAlmostEqual(a.distance(d), c.distance(b)) + ac = Line(a, c) + assert ac.diff_side(b, d), (ac(b), ac(d)) + bd = Line(b, d) + assert bd.diff_side(a, c), (bd(a), bd(c)) + + def test_sketch_eq_trapezoid(self): + a, b, c, d = nm.sketch_eq_trapezoid([]) + assert Line(a, b).is_parallel(Line(c, d)) + self.assertAlmostEqual(a.distance(d), b.distance(c)) + + def test_sketch_eqangle3(self): + points = random_points(5) + x = nm.sketch_eqangle3(points).sample_within(points)[0] + a, b, d, e, f = points + self.assertTrue(check_eqangle([x, a, x, b, d, e, d, f])) + + def test_sketch_eqangle2(self): + a, b, c = random_points(3) + x = nm.sketch_eqangle2([a, b, c]) + ang1 = ang_between(a, b, x) + ang2 = ang_between(c, x, b) + self.assertAlmostEqual(ang1, ang2) + + def test_sketch_eqdia_quadrangle(self): + a, b, c, d = nm.sketch_eqdia_quadrangle([]) + assert Line(a, c).diff_side(b, d) + assert Line(b, d).diff_side(a, c) + self.assertAlmostEqual(a.distance(c), b.distance(d)) + + def test_sketch_isos(self): + a, b, c = nm.sketch_isos([]) + self.assertAlmostEqual(a.distance(b), a.distance(c)) + 
self.assertAlmostEqual(ang_between(b, a, c), ang_between(c, b, a)) + + def test_sketch_quadrangle(self): + a, b, c, d = nm.sketch_quadrangle([]) + self.assertTrue(Line(a, c).diff_side(b, d)) + self.assertTrue(Line(b, d).diff_side(a, c)) + + def test_sketch_r_trapezoid(self): + a, b, c, d = nm.sketch_r_trapezoid([]) + self.assertTrue(Line(a, b).is_perp(Line(a, d))) + self.assertTrue(Line(a, b).is_parallel(Line(c, d))) + self.assertTrue(Line(a, c).diff_side(b, d)) + self.assertTrue(Line(b, d).diff_side(a, c)) + + def test_sketch_r_triangle(self): + a, b, c = nm.sketch_r_triangle([]) + self.assertTrue(Line(a, b).is_perp(Line(a, c))) + + def test_sketch_rectangle(self): + a, b, c, d = nm.sketch_rectangle([]) + self.assertTrue(Line(a, b).is_perp(Line(b, c))) + self.assertTrue(Line(b, c).is_perp(Line(c, d))) + self.assertTrue(Line(c, d).is_perp(Line(d, a))) + + def test_sketch_reflect(self): + a, b, c = random_points(3) + x = nm.sketch_reflect([a, b, c]) + self.assertTrue(Line(a, x).is_perp(Line(b, c))) + self.assertAlmostEqual(x.distance(Line(b, c)), a.distance(Line(b, c))) + + def test_sketch_risos(self): + a, b, c = nm.sketch_risos([]) + self.assertAlmostEqual(a.distance(b), a.distance(c)) + self.assertTrue(Line(a, b).is_perp(Line(a, c))) + + def test_sketch_rotaten90(self): + a, b = random_points(2) + x = nm.sketch_rotaten90([a, b]) + self.assertAlmostEqual(a.distance(x), a.distance(b)) + self.assertTrue(Line(a, x).is_perp(Line(a, b))) + d = Point(0.0, 0.0) + e = Point(0.0, 1.0) + f = Point(1.0, 0.0) + self.assertAlmostEqual(ang_between(d, e, f), ang_between(a, b, x)) + + def test_sketch_rotatep90(self): + a, b = random_points(2) + x = nm.sketch_rotatep90([a, b]) + self.assertAlmostEqual(a.distance(x), a.distance(b)) + self.assertTrue(Line(a, x).is_perp(Line(a, b))) + d = Point(0.0, 0.0) + e = Point(0.0, 1.0) + f = Point(1.0, 0.0) + self.assertAlmostEqual(ang_between(d, f, e), ang_between(a, b, x)) + + def test_sketch_s_angle(self): + a, b = random_points(2) + y = 
unif(0.0, np.pi) + bx = nm.sketch_s_angle([a, b, y / np.pi * 180]) + self.assertIsInstance(bx, HalfLine) + self.assertEqual(bx.tail, b) + x = bx.head + + d = Point(1.0, 0.0) + e = Point(0.0, 0.0) + f = Point(np.cos(y), np.sin(y)) + self.assertAlmostEqual(ang_between(e, d, f), ang_between(b, a, x)) + + def test_sketch_shift(self): + a, b, c = random_points(3) + x = nm.sketch_shift([a, b, c]) + self.assertTrue((b - a).close(x - c)) + + def test_sketch_square(self): + a, b = random_points(2) + c, d = nm.sketch_square([a, b]) + self.assertTrue(Line(a, b).is_perp(Line(b, c))) + self.assertTrue(Line(b, c).is_perp(Line(c, d))) + self.assertTrue(Line(c, d).is_perp(Line(d, a))) + self.assertAlmostEqual(a.distance(b), b.distance(c)) + + def test_sketch_isquare(self): + a, b, c, d = nm.sketch_isquare([]) + self.assertTrue(Line(a, b).is_perp(Line(b, c))) + self.assertTrue(Line(b, c).is_perp(Line(c, d))) + self.assertTrue(Line(c, d).is_perp(Line(d, a))) + self.assertAlmostEqual(a.distance(b), b.distance(c)) + + def test_sketch_trapezoid(self): + a, b, c, d = nm.sketch_trapezoid([]) + self.assertTrue(Line(a, b).is_parallel(Line(c, d))) + self.assertTrue(Line(a, c).diff_side(b, d)) + self.assertTrue(Line(b, d).diff_side(a, c)) + + def test_sketch_triangle(self): + a, b, c = nm.sketch_triangle([]) + self.assertFalse(check_coll([a, b, c])) + + def test_sketch_triangle12(self): + a, b, c = nm.sketch_triangle12([]) + self.assertAlmostEqual(a.distance(b) * 2, a.distance(c)) + + def test_sketch_trisect(self): + a, b, c = random_points(3) + x, y = nm.sketch_trisect([a, b, c]) + self.assertAlmostEqual(ang_between(b, a, x), ang_between(b, x, y)) + self.assertAlmostEqual(ang_between(b, x, y), ang_between(b, y, c)) + self.assertAlmostEqual(ang_between(b, a, x) * 3, ang_between(b, a, c)) + + def test_sketch_trisegment(self): + a, b = random_points(2) + x, y = nm.sketch_trisegment([a, b]) + self.assertAlmostEqual( + a.distance(x) + x.distance(y) + y.distance(b), a.distance(b) + ) + 
self.assertAlmostEqual(a.distance(x), x.distance(y)) + self.assertAlmostEqual(x.distance(y), y.distance(b)) + + +if __name__ == '__main__': + absltest.main() diff --git a/pretty.py b/pretty.py new file mode 100644 index 0000000..d794fd8 --- /dev/null +++ b/pretty.py @@ -0,0 +1,216 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities for string manipulation in the DSL.""" + +MAP_SYMBOL = { + 'T': 'perp', + 'P': 'para', + 'D': 'cong', + 'S': 'simtri', + 'I': 'circle', + 'M': 'midp', + 'O': 'cyclic', + 'C': 'coll', + '^': 'eqangle', + '/': 'eqratio', + '%': 'eqratio', + '=': 'contri', + 'X': 'collx', + 'A': 'acompute', + 'R': 'rcompute', + 'Q': 'fixc', + 'E': 'fixl', + 'V': 'fixb', + 'H': 'fixt', + 'Z': 'fixp', + 'Y': 'ind', +} + + +def map_symbol(c: str) -> str: + return MAP_SYMBOL[c] + + +def map_symbol_inv(c: str) -> str: + return {v: k for k, v in MAP_SYMBOL.items()}[c] + + +def _gcd(x: int, y: int) -> int: + while y: + x, y = y, x % y + return x + + +def simplify(n: int, d: int) -> tuple[int, int]: + g = _gcd(n, d) + return (n // g, d // g) + + +def pretty2r(a: str, b: str, c: str, d: str) -> str: + if b in (c, d): + a, b = b, a + + if a == d: + c, d = d, c + + return f'{a} {b} {c} {d}' + + +def pretty2a(a: str, b: str, c: str, d: str) -> str: + if b in (c, d): + a, b = b, a + + if a == d: + c, d = d, c + + return f'{a} {b} 
{c} {d}' + + +def pretty_angle(a: str, b: str, c: str, d: str) -> str: + if b in (c, d): + a, b = b, a + if a == d: + c, d = d, c + + if a == c: + return f'\u2220{b}{a}{d}' + return f'\u2220({a}{b}-{c}{d})' + + +def pretty_nl(name: str, args: list[str]) -> str: + """Natural lang formatting a predicate.""" + if name == 'aconst': + a, b, c, d, y = args + return f'{pretty_angle(a, b, c, d)} = {y}' + if name == 'rconst': + a, b, c, d, y = args + return f'{a}{b}:{c}{d} = {y}' + if name == 'acompute': + a, b, c, d = args + return f'{pretty_angle(a, b, c, d)}' + if name in ['coll', 'C']: + return '' + ','.join(args) + ' are collinear' + if name == 'collx': + return '' + ','.join(list(set(args))) + ' are collinear' + if name in ['cyclic', 'O']: + return '' + ','.join(args) + ' are concyclic' + if name in ['midp', 'midpoint', 'M']: + x, a, b = args + return f'{x} is midpoint of {a}{b}' + if name in ['eqangle', 'eqangle6', '^']: + a, b, c, d, e, f, g, h = args + return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}' + if name in ['eqratio', 'eqratio6', '/']: + return '{}{}:{}{} = {}{}:{}{}'.format(*args) + if name == 'eqratio3': + a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name + return f'S {o} {a} {b} {o} {c} {d}' + if name in ['cong', 'D']: + a, b, c, d = args + return f'{a}{b} = {c}{d}' + if name in ['perp', 'T']: + if len(args) == 2: # this is algebraic derivation. + ab, cd = args # ab = 'd( ... )' + return f'{ab} \u27c2 {cd}' + a, b, c, d = args + return f'{a}{b} \u27c2 {c}{d}' + if name in ['para', 'P']: + if len(args) == 2: # this is algebraic derivation. + ab, cd = args # ab = 'd( ... 
)' + return f'{ab} \u2225 {cd}' + a, b, c, d = args + return f'{a}{b} \u2225 {c}{d}' + if name in ['simtri2', 'simtri', 'simtri*']: + a, b, c, x, y, z = args + return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}' + if name in ['contri2', 'contri', 'contri*']: + a, b, c, x, y, z = args + return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}' + if name in ['circle', 'I']: + o, a, b, c = args + return f'{o} is the circumcenter of \\Delta {a}{b}{c}' + if name == 'foot': + a, b, c, d = args + return f'{a} is the foot of {b} on {c}{d}' + + +def pretty(txt: str) -> str: + """Pretty formatting a predicate string.""" + if isinstance(txt, str): + txt = txt.split(' ') + name, *args = txt + if name == 'ind': + return 'Y ' + ' '.join(args) + if name in ['fixc', 'fixl', 'fixb', 'fixt', 'fixp']: + return map_symbol_inv(name) + ' ' + ' '.join(args) + if name == 'acompute': + a, b, c, d = args + return 'A ' + ' '.join(args) + if name == 'rcompute': + a, b, c, d = args + return 'R ' + ' '.join(args) + if name == 'aconst': + a, b, c, d, y = args + return f'^ {pretty2a(a, b, c, d)} {y}' + if name == 'rconst': + a, b, c, d, y = args + return f'/ {pretty2r(a, b, c, d)} {y}' + if name == 'coll': + return 'C ' + ' '.join(args) + if name == 'collx': + return 'X ' + ' '.join(args) + if name == 'cyclic': + return 'O ' + ' '.join(args) + if name in ['midp', 'midpoint']: + x, a, b = args + return f'M {x} {a} {b}' + if name == 'eqangle': + a, b, c, d, e, f, g, h = args + return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}' + if name == 'eqratio': + a, b, c, d, e, f, g, h = args + return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}' + if name == 'eqratio3': + a, b, c, d, o, o = args  # pylint: disable=redeclared-assigned-name + return f'S {o} {a} {b} {o} {c} {d}' + if name == 'cong': + a, b, c, d = args + return f'D {a} {b} {c} {d}' + if name == 'perp': + if len(args) == 2:  # this is algebraic derivation. + ab, cd = args  # ab = 'd( ... 
)' + return f'T {ab} {cd}' + a, b, c, d = args + return f'T {a} {b} {c} {d}' + if name == 'para': + if len(args) == 2: # this is algebraic derivation. + ab, cd = args # ab = 'd( ... )' + return f'P {ab} {cd}' + a, b, c, d = args + return f'P {a} {b} {c} {d}' + if name in ['simtri2', 'simtri', 'simtri*']: + a, b, c, x, y, z = args + return f'S {a} {b} {c} {x} {y} {z}' + if name in ['contri2', 'contri', 'contri*']: + a, b, c, x, y, z = args + return f'= {a} {b} {c} {x} {y} {z}' + if name == 'circle': + o, a, b, c = args + return f'I {o} {a} {b} {c}' + if name == 'foot': + a, b, c, d = args + return f'F {a} {b} {c} {d}' + return ' '.join(txt) diff --git a/problem.py b/problem.py new file mode 100644 index 0000000..9108837 --- /dev/null +++ b/problem.py @@ -0,0 +1,1133 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Implements objects to represent problems, theorems, proofs, traceback.""" + +from __future__ import annotations + +from collections import defaultdict # pylint: disable=g-importing-member +from typing import Any + +import geometry as gm +import pretty as pt + + +# pylint: disable=protected-access +# pylint: disable=unused-variable +# pylint: disable=unused-argument +# pylint: disable=unused-assignment + + +def reshape(l: list[Any], n: int = 1) -> list[list[Any]]: + assert len(l) % n == 0 + columns = [[] for i in range(n)] + for i, x in enumerate(l): + columns[i % n].append(x) + return zip(*columns) + + +def isint(x: str) -> bool: + try: + int(x) + return True + except: # pylint: disable=bare-except + return False + + +class Construction: + """One predicate.""" + + @classmethod + def from_txt(cls, data: str) -> Construction: + data = data.split(' ') + return Construction(data[0], data[1:]) + + def __init__(self, name: str, args: list[str]): + self.name = name + self.args = args + + def translate(self, mapping: dict[str, str]) -> Construction: + args = [a if isint(a) else mapping[a] for a in self.args] + return Construction(self.name, args) + + def txt(self) -> str: + return ' '.join([self.name] + list(self.args)) + + +class Clause: + """One construction (>= 1 predicate).""" + + @classmethod + def from_txt(cls, data: str) -> Clause: + if data == ' =': + return Clause([], []) + points, constructions = data.split(' = ') + return Clause( + points.split(' '), + [Construction.from_txt(c) for c in constructions.split(', ')], + ) + + def __init__(self, points: list[str], constructions: list[Construction]): + self.points = [] + self.nums = [] + + for p in points: + num = None + if isinstance(p, str) and '@' in p: + p, num = p.split('@') + x, y = num.split('_') + num = float(x), float(y) + self.points.append(p) + self.nums.append(num) + + self.constructions = constructions + + def 
translate(self, mapping: dict[str, str]) -> Clause: + points0 = [] + for p in self.points: + pcount = len(mapping) + 1 + name = chr(96 + pcount) + if name > 'z': # pcount = 26 -> name = 'z' + name = chr(97 + (pcount - 1) % 26) + str((pcount - 1) // 26) + + p0 = mapping.get(p, name) + mapping[p] = p0 + points0.append(p0) + return Clause(points0, [c.translate(mapping) for c in self.constructions]) + + def add(self, name: str, args: list[str]) -> None: + self.constructions.append(Construction(name, args)) + + def txt(self) -> str: + return ( + ' '.join(self.points) + + ' = ' + + ', '.join(c.txt() for c in self.constructions) + ) + + +def _gcd(x: int, y: int) -> int: + while y: + x, y = y, x % y + return x + + +def simplify(n: int, d: int) -> tuple[int, int]: + g = _gcd(n, d) + return (n // g, d // g) + + +def compare_fn(dep: Dependency) -> tuple[Dependency, str]: + return (dep, pt.pretty(dep)) + + +def sort_deps(deps: list[Dependency]) -> list[Dependency]: + return sorted(deps, key=compare_fn) + + +class Problem: + """Describe one problem to solve.""" + + @classmethod + def from_txt_file( + cls, fname: str, to_dict: bool = False, translate: bool = True + ): + """Load a problem from a text file.""" + with open(fname, 'r') as f: + lines = f.read().split('\n') + + lines = [l for l in lines if l] + data = [ + cls.from_txt(url + '\n' + problem, translate) + for (url, problem) in reshape(lines, 2) + ] + if to_dict: + return cls.to_dict(data) + return data + + @classmethod + def from_txt(cls, data: str, translate: bool = True) -> Problem: + """Load a problem from a str object.""" + url = '' + if '\n' in data: + url, data = data.split('\n') + + if ' ? ' in data: + clauses, goal = data.split(' ? 
') + goal = Construction.from_txt(goal) + else: + clauses, goal = data, None + + clauses = clauses.split('; ') + problem = Problem( + url=url, clauses=[Clause.from_txt(c) for c in clauses], goal=goal + ) + if translate: + return problem.translate() + return problem + + @classmethod + def to_dict(cls, data: list[Problem]) -> dict[str, Problem]: + return {p.url: p for p in data} + + def __init__(self, url: str, clauses: list[Clause], goal: Construction): + self.url = url + self.clauses = clauses + self.goal = goal + + def copy(self) -> Problem: + return Problem(self.url, list(self.clauses), self.goal) + + def translate(self) -> Problem: # to single-char point names + """Translate point names into alphabetical.""" + mapping = {} + clauses = [] + + for clause in self.clauses: + clauses.append(clause.translate(mapping)) + + if self.goal: + goal = self.goal.translate(mapping) + else: + goal = self.goal + + p = Problem(self.url, clauses, goal) + p.mapping = mapping + return p + + def txt(self) -> str: + return ( + '; '.join([c.txt() for c in self.clauses]) + ' ? 
' + self.goal.txt() + if self.goal + else '' + ) + + def setup_str_from_problem(self, definitions: list[Definition]) -> str: + """Construct the string from Problem object.""" + ref = 0 + + string = [] + for clause in self.clauses: + group = {} + p2deps = defaultdict(list) + for c in clause.constructions: + cdef = definitions[c.name] + + if len(c.args) != len(cdef.construction.args): + assert len(c.args) + len(clause.points) == len(cdef.construction.args) + c.args = clause.points + c.args + + mapping = dict(zip(cdef.construction.args, c.args)) + for points, bs in cdef.basics: + points = tuple([mapping[x] for x in points]) + for p in points: + group[p] = points + + for b in bs: + args = [mapping[a] for a in b.args] + name = b.name + if b.name in ['s_angle', 'aconst']: + x, y, z, v = args + name = 'aconst' + v = int(v) + + if v < 0: + v = -v + x, z = z, x + + m, n = simplify(int(v), 180) + args = [y, z, y, x, f'{m}pi/{n}'] + + p2deps[points].append(hashed_txt(name, args)) + + for k, v in p2deps.items(): + p2deps[k] = sort_deps(v) + + points = clause.points + while points: + p = points[0] + gr = group[p] + points = [x for x in points if x not in gr] + + deps_str = [] + for dep in p2deps[gr]: + ref_str = '{:02}'.format(ref) + dep_str = pt.pretty(dep) + + if dep[0] == 'aconst': + m, n = map(int, dep[-1].split('pi/')) + mn = f'{m}. pi / {n}.' + dep_str = ' '.join(dep_str.split()[:-1] + [mn]) + + deps_str.append(dep_str + ' ' + ref_str) + ref += 1 + + string.append(' '.join(gr) + ' : ' + ' '.join(deps_str)) + + string = '{S} ' + ' ; '.join([s.strip() for s in string]) + goal = self.goal + string += ' ? 
' + pt.pretty([goal.name] + goal.args) + return string + + +def parse_rely(s: str) -> dict[str, str]: + result = {} + if not s: + return result + s = [x.strip() for x in s.split(',')] + for x in s: + a, b = x.split(':') + a, b = a.strip().split(), b.strip().split() + result.update({m: b for m in a}) + return result + + +class Definition: + """Definitions of construction statements.""" + + @classmethod + def from_txt_file(cls, fname: str, to_dict: bool = False) -> Definition: + with open(fname, 'r') as f: + lines = f.read() + return cls.from_string(lines, to_dict) + + @classmethod + def from_string(cls, string: str, to_dict: bool = False) -> Definition: + lines = string.split('\n') + data = [cls.from_txt('\n'.join(group)) for group in reshape(lines, 6)] + if to_dict: + return cls.to_dict(data) + return data + + @classmethod + def to_dict(cls, data: list[Definition]) -> dict[str, Definition]: + return {d.construction.name: d for d in data} + + @classmethod + def from_txt(cls, data: str) -> Definition: + """Load definitions from a str object.""" + construction, rely, deps, basics, numerics, _ = data.split('\n') + basics = [] if not basics else [b.strip() for b in basics.split(';')] + + levels = [] + for bs in basics: + if ':' in bs: + points, bs = bs.split(':') + points = points.strip().split() + else: + points = [] + if bs.strip(): + bs = [Construction.from_txt(b.strip()) for b in bs.strip().split(',')] + else: + bs = [] + levels.append((points, bs)) + + numerics = [] if not numerics else numerics.split(', ') + + return Definition( + construction=Construction.from_txt(construction), + rely=parse_rely(rely), + deps=Clause.from_txt(deps), + basics=levels, + numerics=[Construction.from_txt(c) for c in numerics], + ) + + def __init__( + self, + construction: Construction, + rely: dict[str, str], + deps: Clause, + basics: list[tuple[list[str], list[Construction]]], + numerics: list[Construction], + ): + self.construction = construction + self.rely = rely + self.deps = 
deps + self.basics = basics + self.numerics = numerics + + args = set() + for num in numerics: + args.update(num.args) + + self.points = [] + self.args = [] + for p in self.construction.args: + if p in args: + self.args.append(p) + else: + self.points.append(p) + + +class Theorem: + """Deduction rule.""" + + @classmethod + def from_txt_file(cls, fname: str, to_dict: bool = False) -> Theorem: + with open(fname, 'r') as f: + theorems = f.read() + return cls.from_string(theorems, to_dict) + + @classmethod + def from_string(cls, string: str, to_dict: bool = False) -> Theorem: + """Load deduction rule from a str object.""" + theorems = string.split('\n') + theorems = [l for l in theorems if l and not l.startswith('#')] + theorems = [cls.from_txt(l) for l in theorems] + + for i, th in enumerate(theorems): + th.rule_name = 'r{:02}'.format(i) + + if to_dict: + result = {} + for t in theorems: + if t.name in result: + t.name += '_' + result[t.rule_name] = t + + return result + + return theorems + + @classmethod + def from_txt(cls, data: str) -> Theorem: + premises, conclusion = data.split(' => ') + premises = premises.split(', ') + conclusion = conclusion.split(', ') + return Theorem( + premise=[Construction.from_txt(p) for p in premises], + conclusion=[Construction.from_txt(c) for c in conclusion], + ) + + def __init__( + self, premise: list[Construction], conclusion: list[Construction] + ): + if len(conclusion) != 1: + raise ValueError('Cannot have more than one conclusion') + self.name = '_'.join([p.name for p in premise + conclusion]) + self.premise = premise + self.conclusion = conclusion + self.is_arg_reduce = False + + assert len(self.conclusion) == 1 + con = self.conclusion[0] + + if con.name in [ + 'eqratio3', + 'midp', + 'contri', + 'simtri', + 'contri2', + 'simtri2', + 'simtri*', + 'contri*', + ]: + return + + prem_args = set(sum([p.args for p in self.premise], [])) + con_args = set(con.args) + if len(prem_args) <= len(con_args): + self.is_arg_reduce = True + + 
def txt(self) -> str: + premise_txt = ', '.join([clause.txt() for clause in self.premise]) + conclusion_txt = ', '.join([clause.txt() for clause in self.conclusion]) + return f'{premise_txt} => {conclusion_txt}' + + def conclusion_name_args( + self, mapping: dict[str, gm.Point] + ) -> tuple[str, list[gm.Point]]: + mapping = {arg: p for arg, p in mapping.items() if isinstance(arg, str)} + c = self.conclusion[0] + args = [mapping[a] for a in c.args] + return c.name, args + + +def why_eqratio( + d1: gm.Direction, + d2: gm.Direction, + d3: gm.Direction, + d4: gm.Direction, + level: int, +) -> list[Dependency]: + """Why two ratios are equal, returns a Dependency objects.""" + all12 = list(gm.all_ratios(d1, d2, level)) + all34 = list(gm.all_ratios(d3, d4, level)) + + min_why = None + for ang12, d1s, d2s in all12: + for ang34, d3s, d4s in all34: + why0 = gm.why_equal(ang12, ang34, level) + if why0 is None: + continue + d1_, d2_ = ang12._l + d3_, d4_ = ang34._l + why1 = gm.bfs_backtrack(d1, [d1_], d1s) + why2 = gm.bfs_backtrack(d2, [d2_], d2s) + why3 = gm.bfs_backtrack(d3, [d3_], d3s) + why4 = gm.bfs_backtrack(d4, [d4_], d4s) + why = why0 + why1 + why2 + why3 + why4 + if min_why is None or len(why) < len(min_why[0]): + min_why = why, ang12, ang34, why0, why1, why2, why3, why4 + + if min_why is None: + return None + + _, ang12, ang34, why0, why1, why2, why3, why4 = min_why + d1_, d2_ = ang12._l + d3_, d4_ = ang34._l + + if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_: + return why0 + + (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points + (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points + deps = [] + if why0: + dep = Dependency('eqratio', [a_, b_, c_, d_, e_, f_, g_, h_], '', level) + dep.why = why0 + deps.append(dep) + + (a, b), (c, d) = d1._obj.points, d2._obj.points + (e, f), (g, h) = d3._obj.points, d4._obj.points + for why, (x, y), (x_, y_) in zip( + [why1, why2, why3, why4], + [(a, b), (c, d), (e, f), (g, h)], + [(a_, b_), (c_, d_), (e_, f_), (g_, h_)], 
+  ):
+    if why:
+      dep = Dependency('cong', [x, y, x_, y_], '', level)
+      dep.why = why
+      deps.append(dep)
+
+  return deps
+
+
+def why_eqangle(
+    d1: gm.Direction,
+    d2: gm.Direction,
+    d3: gm.Direction,
+    d4: gm.Direction,
+    level: int,
+    verbose: bool = False,
+) -> list[Dependency]:
+  """Why two angles are equal, returning a list of Dependency objects."""
+  all12 = list(gm.all_angles(d1, d2, level))
+  all34 = list(gm.all_angles(d3, d4, level))
+
+  min_why = None
+  for ang12, d1s, d2s in all12:
+    for ang34, d3s, d4s in all34:
+      why0 = gm.why_equal(ang12, ang34, level)
+      if why0 is None:
+        continue
+      d1_, d2_ = ang12._d
+      d3_, d4_ = ang34._d
+      why1 = gm.bfs_backtrack(d1, [d1_], d1s)
+      why2 = gm.bfs_backtrack(d2, [d2_], d2s)
+      why3 = gm.bfs_backtrack(d3, [d3_], d3s)
+      why4 = gm.bfs_backtrack(d4, [d4_], d4s)
+      why = why0 + why1 + why2 + why3 + why4
+      if min_why is None or len(why) < len(min_why[0]):
+        min_why = why, ang12, ang34, why0, why1, why2, why3, why4
+
+  if min_why is None:
+    return None
+
+  _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
+  why0 = gm.why_equal(ang12, ang34, level)
+  d1_, d2_ = ang12._d
+  d3_, d4_ = ang34._d
+
+  if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
+    return (d1_, d2_, d3_, d4_), why0
+
+  (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
+  (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
+  deps = []
+  if why0:
+    dep = Dependency('eqangle', [a_, b_, c_, d_, e_, f_, g_, h_], '', None)
+    dep.why = why0
+    deps.append(dep)
+
+  (a, b), (c, d) = d1._obj.points, d2._obj.points
+  (e, f), (g, h) = d3._obj.points, d4._obj.points
+  for why, d_xy, (x, y), d_xy_, (x_, y_) in zip(
+      [why1, why2, why3, why4],
+      [d1, d2, d3, d4],
+      [(a, b), (c, d), (e, f), (g, h)],
+      [d1_, d2_, d3_, d4_],
+      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
+  ):
+    xy, xy_ = d_xy._obj, d_xy_._obj
+    if why:
+      if xy == xy_:
+        name = 'collx'
+      else:
+        name = 'para'
+      dep = Dependency(name, [x_, y_, x, y], '', None)
+      dep.why = why
+      deps.append(dep)
+
+  return (d1_, 
d2_, d3_, d4_), deps + + +CONSTRUCTION_RULE = 'c0' + + +class EmptyDependency: + """Empty dependency predicate ready to get filled up.""" + + def __init__(self, level: int, rule_name: str): + self.level = level + self.rule_name = rule_name or '' + self.empty = True + self.why = [] + self.trace = None + + def populate(self, name: str, args: list[gm.Point]) -> Dependency: + dep = Dependency(name, args, self.rule_name, self.level) + dep.trace2 = self.trace + dep.why = list(self.why) + return dep + + def copy(self) -> EmptyDependency: + other = EmptyDependency(self.level, self.rule_name) + other.why = list(self.why) + return other + + def extend( + self, + g: Any, + name0: str, + args0: list[gm.Point], + name: str, + args: list[gm.Point], + ) -> EmptyDependency: + """Extend the dependency list by (name, args).""" + dep0 = self.populate(name0, args0) + deps = EmptyDependency(level=self.level, rule_name=None) + dep = Dependency(name, args, None, deps.level) + deps.why = [dep0, dep.why_me_or_cache(g, None)] + return deps + + def extend_many( + self, + g: Any, + name0: str, + args0: list[gm.Point], + name_args: list[tuple[str, list[gm.Point]]], + ) -> EmptyDependency: + """Extend the dependency list by many name_args.""" + if not name_args: + return self + dep0 = self.populate(name0, args0) + deps = EmptyDependency(level=self.level, rule_name=None) + deps.why = [dep0] + for name, args in name_args: + dep = Dependency(name, args, None, deps.level) + deps.why += [dep.why_me_or_cache(g, None)] + return deps + + +def maybe_make_equal_pairs( + a: gm.Point, + b: gm.Point, + c: gm.Point, + d: gm.Point, + m: gm.Point, + n: gm.Point, + p: gm.Point, + q: gm.Point, + ab: gm.Line, + mn: gm.Line, + g: Any, + level: int, +) -> list[Dependency]: + """Make a-b:c-d==m-n:p-q in case a-b==m-n or c-d==p-q.""" + if ab != mn: + return + why = [] + eqname = 'para' if isinstance(ab, gm.Line) else 'cong' + colls = [a, b, m, n] + if len(set(colls)) > 2 and eqname == 'para': + dep = 
Dependency('collx', colls, None, level)
+      dep.why_me(g, level)
+      why += [dep]
+
+    dep = Dependency(eqname, [c, d, p, q], None, level)
+    dep.why_me(g, level)
+    why += [dep]
+    return why
+
+
+class Dependency(Construction):
+  """Dependency is a predicate that other predicates depend on."""
+
+  def __init__(
+      self, name: str, args: list[gm.Point], rule_name: str, level: int
+  ):
+    super().__init__(name, args)
+    self.rule_name = rule_name or ''
+    self.level = level
+    self.why = []
+
+    self._stat = None
+    self.trace = None
+
+  def _find(self, dep_hashed: tuple[str, ...]) -> Dependency:
+    for w in self.why:
+      f = w._find(dep_hashed)
+      if f:
+        return f
+      if w.hashed() == dep_hashed:
+        return w
+
+  def remove_loop(self) -> Dependency:
+    f = self._find(self.hashed())
+    if f:
+      return f
+    return self
+
+  def copy(self) -> Dependency:
+    dep = Dependency(self.name, self.args, self.rule_name, self.level)
+    dep.trace = self.trace
+    dep.why = list(self.why)
+    return dep
+
+  def why_me_or_cache(self, g: Any, level: int) -> Dependency:
+    if self.hashed() in g.cache:
+      return g.cache[self.hashed()]
+    self.why_me(g, level)
+    return self
+
+  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
+    assert self.rule_name == CONSTRUCTION_RULE, self.rule_name
+    dep = Dependency(self.name, self.args, self.rule_name, self.level)
+    dep.why = list(self.why)
+    return dep
+
+  def why_me(self, g: Any, level: int) -> None:
+    """Figure out the dependency predicates of self."""
+    name, args = self.name, self.args
+
+    hashed_me = hashed(name, args)
+    if hashed_me in g.cache:
+      dep = g.cache[hashed_me]
+      self.why = dep.why
+      self.rule_name = dep.rule_name
+      return
+
+    if self.name == 'para':
+      a, b, c, d = self.args
+      if {a, b} == {c, d}:
+        self.why = []
+        return
+
+      ab = g._get_line(a, b)
+      cd = g._get_line(c, d)
+      if ab == cd:
+        if {a, b} == {c, d}:
+          self.why = []
+          self.rule_name = ''
+          return
+        dep = Dependency('coll', list({a, b, c, d}), 't??', None)
+        self.why = 
[dep.why_me_or_cache(g, level)] + return + + for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]): + x_, y_ = xy.points + if {x, y} == {x_, y_}: + continue + d = Dependency('collx', [x, y, x_, y_], None, level) + self.why += [d.why_me_or_cache(g, level)] + + whypara = g.why_equal(ab, cd, None) + self.why += whypara + + elif self.name == 'midp': + m, a, b = self.args + ma = g._get_segment(m, a) + mb = g._get_segment(m, b) + dep = Dependency('coll', [m, a, b], None, None).why_me_or_cache(g, None) + self.why = [dep] + g.why_equal(ma, mb, level) + + elif self.name == 'perp': + a, b, c, d = self.args + ab = g._get_line(a, b) + cd = g._get_line(c, d) + for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]): + x_, y_ = xy.points + if {x, y} == {x_, y_}: + continue + d = Dependency('collx', [x, y, x_, y_], None, level) + self.why += [d.why_me_or_cache(g, level)] + + _, why = why_eqangle(ab._val, cd._val, cd._val, ab._val, level) + a, b = ab.points + c, d = cd.points + + if hashed(self.name, [a, b, c, d]) != self.hashed(): + d = Dependency(self.name, [a, b, c, d], None, level) + d.why = why + why = [d] + + self.why += why + + elif self.name == 'cong': + a, b, c, d = self.args + ab = g._get_segment(a, b) + cd = g._get_segment(c, d) + + self.why = g.why_equal(ab, cd, level) + + elif self.name == 'coll': + _, why = gm.line_of_and_why(self.args, level) + self.why = why + + elif self.name == 'collx': + if g.check_coll(self.args): + args = list(set(self.args)) + hashed_me = hashed('coll', args) + if hashed_me in g.cache: + dep = g.cache[hashed_me] + self.why = [dep] + self.rule_name = '' + return + _, self.why = gm.line_of_and_why(args, level) + else: + self.name = 'para' + self.why_me(g, level) + + elif self.name == 'cyclic': + _, why = gm.circle_of_and_why(self.args, level) + self.why = why + + elif self.name == 'circle': + o, a, b, c = self.args + oa = g._get_segment(o, a) + ob = g._get_segment(o, b) + oc = g._get_segment(o, c) + self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, 
oc, level) + + elif self.name in ['eqangle', 'eqangle6']: + a, b, c, d, m, n, p, q = self.args + + ab, why1 = g.get_line_thru_pair_why(a, b) + cd, why2 = g.get_line_thru_pair_why(c, d) + mn, why3 = g.get_line_thru_pair_why(m, n) + pq, why4 = g.get_line_thru_pair_why(p, q) + + if ab is None or cd is None or mn is None or pq is None: + if {a, b} == {m, n}: + d = Dependency('para', [c, d, p, q], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {a, b} == {c, d}: + d = Dependency('para', [p, q, m, n], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {c, d} == {p, q}: + d = Dependency('para', [a, b, m, n], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {p, q} == {m, n}: + d = Dependency('para', [a, b, c, d], None, level) + self.why = [d.why_me_or_cache(g, level)] + return + + for (x, y), xy, whyxy in zip( + [(a, b), (c, d), (m, n), (p, q)], + [ab, cd, mn, pq], + [why1, why2, why3, why4], + ): + x_, y_ = xy.points + if {x, y} == {x_, y_}: + continue + d = Dependency('collx', [x, y, x_, y_], None, level) + d.why = whyxy + self.why += [d] + + a, b = ab.points + c, d = cd.points + m, n = mn.points + p, q = pq.points + diff = hashed(self.name, [a, b, c, d, m, n, p, q]) != self.hashed() + + whyeqangle = None + if ab._val and cd._val and mn._val and pq._val: + whyeqangle = why_eqangle(ab._val, cd._val, mn._val, pq._val, level) + + if whyeqangle: + (dab, dcd, dmn, dpq), whyeqangle = whyeqangle + if diff: + d = Dependency('eqangle', [a, b, c, d, m, n, p, q], None, level) + d.why = whyeqangle + whyeqangle = [d] + self.why += whyeqangle + + else: + if (ab == cd and mn == pq) or (ab == mn and cd == pq): + self.why += [] + elif ab == mn: + self.why += maybe_make_equal_pairs( + a, b, c, d, m, n, p, q, ab, mn, g, level + ) + elif cd == pq: + self.why += maybe_make_equal_pairs( + c, d, a, b, p, q, m, n, cd, pq, g, level + ) + elif ab == cd: + self.why += maybe_make_equal_pairs( + a, b, m, n, c, d, p, q, ab, cd, g, level + ) + elif mn == pq: + 
self.why += maybe_make_equal_pairs( + m, n, a, b, p, q, c, d, mn, pq, g, level + ) + elif g.is_equal(ab, mn) or g.is_equal(cd, pq): + dep1 = Dependency('para', [a, b, m, n], None, level) + dep1.why_me(g, level) + dep2 = Dependency('para', [c, d, p, q], None, level) + dep2.why_me(g, level) + self.why += [dep1, dep2] + elif g.is_equal(ab, cd) or g.is_equal(mn, pq): + dep1 = Dependency('para', [a, b, c, d], None, level) + dep1.why_me(g, level) + dep2 = Dependency('para', [m, n, p, q], None, level) + dep2.why_me(g, level) + self.why += [dep1, dep2] + elif ab._val and cd._val and mn._val and pq._val: + self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level) + + elif self.name in ['eqratio', 'eqratio6']: + a, b, c, d, m, n, p, q = self.args + ab = g._get_segment(a, b) + cd = g._get_segment(c, d) + mn = g._get_segment(m, n) + pq = g._get_segment(p, q) + + if ab is None or cd is None or mn is None or pq is None: + if {a, b} == {m, n}: + d = Dependency('cong', [c, d, p, q], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {a, b} == {c, d}: + d = Dependency('cong', [p, q, m, n], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {c, d} == {p, q}: + d = Dependency('cong', [a, b, m, n], None, level) + self.why = [d.why_me_or_cache(g, level)] + if {p, q} == {m, n}: + d = Dependency('cong', [a, b, c, d], None, level) + self.why = [d.why_me_or_cache(g, level)] + return + + if ab._val and cd._val and mn._val and pq._val: + self.why = why_eqratio(ab._val, cd._val, mn._val, pq._val, level) + + if self.why is None: + self.why = [] + if (ab == cd and mn == pq) or (ab == mn and cd == pq): + self.why = [] + elif ab == mn: + self.why += maybe_make_equal_pairs( + a, b, c, d, m, n, p, q, ab, mn, g, level + ) + elif cd == pq: + self.why += maybe_make_equal_pairs( + c, d, a, b, p, q, m, n, cd, pq, g, level + ) + elif ab == cd: + self.why += maybe_make_equal_pairs( + a, b, m, n, c, d, p, q, ab, cd, g, level + ) + elif mn == pq: + self.why += 
maybe_make_equal_pairs( + m, n, a, b, p, q, c, d, mn, pq, g, level + ) + elif g.is_equal(ab, mn) or g.is_equal(cd, pq): + dep1 = Dependency('cong', [a, b, m, n], None, level) + dep1.why_me(g, level) + dep2 = Dependency('cong', [c, d, p, q], None, level) + dep2.why_me(g, level) + self.why += [dep1, dep2] + elif g.is_equal(ab, cd) or g.is_equal(mn, pq): + dep1 = Dependency('cong', [a, b, c, d], None, level) + dep1.why_me(g, level) + dep2 = Dependency('cong', [m, n, p, q], None, level) + dep2.why_me(g, level) + self.why += [dep1, dep2] + elif ab._val and cd._val and mn._val and pq._val: + self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level) + + elif self.name in ['diff', 'npara', 'nperp', 'ncoll', 'sameside']: + self.why = [] + + elif self.name == 'simtri': + a, b, c, x, y, z = self.args + dep1 = Dependency('eqangle', [a, b, a, c, x, y, x, z], '', level) + dep1.why_me(g, level) + dep2 = Dependency('eqangle', [b, a, b, c, y, x, y, z], '', level) + dep2.why_me(g, level) + self.rule_name = 'r34' + self.why = [dep1, dep2] + + elif self.name == 'contri': + a, b, c, x, y, z = self.args + dep1 = Dependency('cong', [a, b, x, y], '', level) + dep1.why_me(g, level) + dep2 = Dependency('cong', [b, c, y, z], '', level) + dep2.why_me(g, level) + dep3 = Dependency('cong', [c, a, z, x], '', level) + dep3.why_me(g, level) + self.rule_name = 'r32' + self.why = [dep1, dep2, dep3] + + elif self.name == 'ind': + pass + + elif self.name == 'aconst': + a, b, c, d, ang0 = self.args + + measure = ang0._val + + for ang in measure.neighbors(gm.Angle): + if ang == ang0: + continue + d1, d2 = ang._d + l1, l2 = d1._obj, d2._obj + (a1, b1), (c1, d1) = l1.points, l2.points + + if not g.check_para_or_coll([a, b, a1, b1]) or not g.check_para_or_coll( + [c, d, c1, d1] + ): + continue + + self.why = [] + for args in [(a, b, a1, b1), (c, d, c1, d1)]: + if g.check_coll(args): + if len(set(args)) > 2: + dep = Dependency('coll', args, None, None) + self.why.append(dep.why_me_or_cache(g, 
level))
+          else:
+            dep = Dependency('para', args, None, None)
+            self.why.append(dep.why_me_or_cache(g, level))
+
+        self.why += gm.why_equal(ang, ang0)
+        break
+
+    elif self.name == 'rconst':
+      a, b, c, d, rat0 = self.args
+
+      val = rat0._val
+
+      for rat in val.neighbors(gm.Ratio):
+        if rat == rat0:
+          continue
+        l1, l2 = rat._l
+        s1, s2 = l1._obj, l2._obj
+        (a1, b1), (c1, d1) = list(s1.points), list(s2.points)
+
+        if not g.check_cong([a, b, a1, b1]) or not g.check_cong([c, d, c1, d1]):
+          continue
+
+        self.why = []
+        for args in [(a, b, a1, b1), (c, d, c1, d1)]:
+          if len(set(args)) > 2:
+            dep = Dependency('cong', args, None, None)
+            self.why.append(dep.why_me_or_cache(g, level))
+
+        self.why += gm.why_equal(rat, rat0)
+        break
+
+    else:
+      raise ValueError('Unrecognized predicate', self.name)
+
+  def hashed(self, rename: bool = False) -> tuple[str, ...]:
+    return hashed(self.name, self.args, rename=rename)
+
+
+def hashed(
+    name: str, args: list[gm.Point], rename: bool = False
+) -> tuple[str, ...]:
+  if name == 's_angle':
+    args = [p.name if not rename else p.new_name for p in args[:-1]] + [
+        str(args[-1])
+    ]
+  else:
+    args = [p.name if not rename else p.new_name for p in args]
+  return hashed_txt(name, args)
+
+
+def hashed_txt(name: str, args: list[str]) -> tuple[str, ...]:
+  """Return a tuple unique to name and args, up to permutation of the args."""
+
+  if name in ['const', 'aconst', 'rconst']:
+    a, b, c, d, y = args
+    a, b = sorted([a, b])
+    c, d = sorted([c, d])
+    return name, a, b, c, d, y
+
+  if name in ['npara', 'nperp', 'para', 'cong', 'perp', 'collx']:
+    a, b, c, d = args
+
+    a, b = sorted([a, b])
+    c, d = sorted([c, d])
+    (a, b), (c, d) = sorted([(a, b), (c, d)])
+
+    return (name, a, b, c, d)
+
+  if name in ['midp', 'midpoint']:
+    a, b, c = args
+    b, c = sorted([b, c])
+    return (name, a, b, c)
+
+  if name in ['coll', 'cyclic', 'ncoll', 'diff', 'triangle']:
+    return (name,) + tuple(sorted(list(set(args))))
+
+  if name == 'circle':
+    x, a, b, c = args
+    return 
(name, x) + tuple(sorted([a, b, c]))
+
+  if name in ['eqangle', 'eqratio', 'eqangle6', 'eqratio6']:
+    a, b, c, d, e, f, g, h = args
+    a, b = sorted([a, b])
+    c, d = sorted([c, d])
+    e, f = sorted([e, f])
+    g, h = sorted([g, h])
+    if tuple(sorted([a, b, e, f])) > tuple(sorted([c, d, g, h])):
+      a, b, e, f, c, d, g, h = c, d, g, h, a, b, e, f
+    if (a, b, c, d) > (e, f, g, h):
+      a, b, c, d, e, f, g, h = e, f, g, h, a, b, c, d
+
+    if name == 'eqangle6':
+      name = 'eqangle'
+    if name == 'eqratio6':
+      name = 'eqratio'
+    return (name,) + (a, b, c, d, e, f, g, h)
+
+  if name in ['contri', 'simtri', 'simtri2', 'contri2', 'contri*', 'simtri*']:
+    a, b, c, x, y, z = args
+    (a, x), (b, y), (c, z) = sorted([(a, x), (b, y), (c, z)], key=sorted)
+    (a, b, c), (x, y, z) = sorted([(a, b, c), (x, y, z)], key=sorted)
+    return (name, a, b, c, x, y, z)
+
+  if name in ['eqratio3']:
+    a, b, c, d, o, o = args  # pylint: disable=redeclared-assigned-name
+    (a, c), (b, d) = sorted([(a, c), (b, d)], key=sorted)
+    (a, b), (c, d) = sorted([(a, b), (c, d)], key=sorted)
+    return (name, a, b, c, d, o, o)
+
+  if name in ['sameside', 's_angle']:
+    return (name,) + tuple(args)
+
+  raise ValueError(f'Unrecognized predicate {name} to hash.')
diff --git a/problem_test.py b/problem_test.py
new file mode 100644
index 0000000..b8fe3ba
--- /dev/null
+++ b/problem_test.py
@@ -0,0 +1,61 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +"""Unit tests for problem.py.""" +import unittest + +from absl.testing import absltest +import problem as pr + + +class ProblemTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + + def test_orthocenter_no_translate(self): + txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long + + # read the txt into pr.Problem object, do not change the name of points: + p = pr.Problem.from_txt(txt, translate=False) + + # This is fed into the LM, translating from constructive to constrained: + setup_str = p.setup_str_from_problem(ProblemTest.defs) + + self.assertEqual( + setup_str, + '{S} a : ; b : ; c : ; h : T a b c h 00 T a c b h 01 ? T a h b c', + ) + + def test_orthocenter_translate(self): + txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long + + # Read the txt into pr.Problem object, change h -> d to match + # training data distribution. + p = pr.Problem.from_txt(txt, translate=True) + + # This is fed into the LM, translating from constructive to constrained: + setup_str = p.setup_str_from_problem(ProblemTest.defs) + + self.assertEqual( + setup_str, + '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? 
T a d b c', + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000..876b2a9 --- /dev/null +++ b/requirements.in @@ -0,0 +1,17 @@ +tensorflow==2.13.0 +numpy==1.23.5 +scipy==1.10.0 +matplotlib==3.7.0 +gdown==4.7.1 +jax==0.4.6 +jaxlib==0.4.6 +flax==0.5.3 +gin-config==0.5.0 +gin==0.1.6 +t5==0.9.4 +sentencepiece==0.1.99 +absl-py==1.4.0 +clu==0.0.7 +optax==0.1.7 +seqio==0.0.18 +tensorflow-datasets==4.9.3 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..36fd376 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2068 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --generate-hashes --resolver=backtracking requirements.in +# +absl-py==1.4.0 \ + --hash=sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47 \ + --hash=sha256:d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d + # via + # -r requirements.in + # array-record + # chex + # clu + # etils + # mesh-tensorflow + # ml-collections + # optax + # rouge-score + # seqio + # seqio-nightly + # t5 + # tensorboard + # tensorflow + # tensorflow-datasets + # tensorflow-metadata + # tfds-nightly +array-record==0.4.1 \ + --hash=sha256:6a0c8ed6fdfaaf2cecd3d5c6b9c13e116ad3299649611c8fd184d64557fbaba8 \ + --hash=sha256:a74e9c0075860777b79e4b3ac278f67add270acf78520d3b9cf8c325aef42951 + # via + # tensorflow-datasets + # tfds-nightly +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via tensorflow +babel==2.13.0 \ + --hash=sha256:04c3e2d28d2b7681644508f836be388ae49e0cfe91465095340395b60d00f210 \ + --hash=sha256:fbfcae1575ff78e26c7449136f1abbefc3c13ce542eeb13d43d50d8b047216ec + # via t5 +beautifulsoup4==4.12.2 \ + --hash=sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da \ 
+ --hash=sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a + # via gdown +cached-property==1.5.2 \ + --hash=sha256:9fa5755838eecbb2d234c3aa390bd80fbd3ac6b6869109bfc1b499f7bd89a130 \ + --hash=sha256:df4f613cf7ad9a588cc381aaf4a512d26265ecebd5eb9e1ba12f1319eb85a6a0 + # via clu +cachetools==5.3.1 \ + --hash=sha256:95ef631eeaea14ba2e36f06437f36463aac3a096799e876ee55e5cdccb102590 \ + --hash=sha256:dce83f2d9b4e1f732a8cd44af8e8fab2dbe46201467fc98b3ef8f269092bf62b + # via google-auth +certifi==2023.7.22 \ + --hash=sha256:539cc1d13202e33ca466e88b2807e29f4c13049d6d87031a3c110744495cb082 \ + --hash=sha256:92d6037539857d8206b8f6ae472e8b77db8058fec5937a1ef3f54304089edbb9 + # via requests +charset-normalizer==3.3.0 \ + --hash=sha256:02673e456dc5ab13659f85196c534dc596d4ef260e4d86e856c3b2773ce09843 \ + --hash=sha256:02af06682e3590ab952599fbadac535ede5d60d78848e555aa58d0c0abbde786 \ + --hash=sha256:03680bb39035fbcffe828eae9c3f8afc0428c91d38e7d61aa992ef7a59fb120e \ + --hash=sha256:0570d21da019941634a531444364f2482e8db0b3425fcd5ac0c36565a64142c8 \ + --hash=sha256:09c77f964f351a7369cc343911e0df63e762e42bac24cd7d18525961c81754f4 \ + --hash=sha256:0d3d5b7db9ed8a2b11a774db2bbea7ba1884430a205dbd54a32d61d7c2a190fa \ + --hash=sha256:1063da2c85b95f2d1a430f1c33b55c9c17ffaf5e612e10aeaad641c55a9e2b9d \ + --hash=sha256:12ebea541c44fdc88ccb794a13fe861cc5e35d64ed689513a5c03d05b53b7c82 \ + --hash=sha256:153e7b6e724761741e0974fc4dcd406d35ba70b92bfe3fedcb497226c93b9da7 \ + --hash=sha256:15b26ddf78d57f1d143bdf32e820fd8935d36abe8a25eb9ec0b5a71c82eb3895 \ + --hash=sha256:1872d01ac8c618a8da634e232f24793883d6e456a66593135aeafe3784b0848d \ + --hash=sha256:187d18082694a29005ba2944c882344b6748d5be69e3a89bf3cc9d878e548d5a \ + --hash=sha256:1b2919306936ac6efb3aed1fbf81039f7087ddadb3160882a57ee2ff74fd2382 \ + --hash=sha256:232ac332403e37e4a03d209a3f92ed9071f7d3dbda70e2a5e9cff1c4ba9f0678 \ + --hash=sha256:23e8565ab7ff33218530bc817922fae827420f143479b753104ab801145b1d5b \ + 
--hash=sha256:24817cb02cbef7cd499f7c9a2735286b4782bd47a5b3516a0e84c50eab44b98e \ + --hash=sha256:249c6470a2b60935bafd1d1d13cd613f8cd8388d53461c67397ee6a0f5dce741 \ + --hash=sha256:24a91a981f185721542a0b7c92e9054b7ab4fea0508a795846bc5b0abf8118d4 \ + --hash=sha256:2502dd2a736c879c0f0d3e2161e74d9907231e25d35794584b1ca5284e43f596 \ + --hash=sha256:250c9eb0f4600361dd80d46112213dff2286231d92d3e52af1e5a6083d10cad9 \ + --hash=sha256:278c296c6f96fa686d74eb449ea1697f3c03dc28b75f873b65b5201806346a69 \ + --hash=sha256:2935ffc78db9645cb2086c2f8f4cfd23d9b73cc0dc80334bc30aac6f03f68f8c \ + --hash=sha256:2f4a0033ce9a76e391542c182f0d48d084855b5fcba5010f707c8e8c34663d77 \ + --hash=sha256:30a85aed0b864ac88309b7d94be09f6046c834ef60762a8833b660139cfbad13 \ + --hash=sha256:380c4bde80bce25c6e4f77b19386f5ec9db230df9f2f2ac1e5ad7af2caa70459 \ + --hash=sha256:3ae38d325b512f63f8da31f826e6cb6c367336f95e418137286ba362925c877e \ + --hash=sha256:3b447982ad46348c02cb90d230b75ac34e9886273df3a93eec0539308a6296d7 \ + --hash=sha256:3debd1150027933210c2fc321527c2299118aa929c2f5a0a80ab6953e3bd1908 \ + --hash=sha256:4162918ef3098851fcd8a628bf9b6a98d10c380725df9e04caf5ca6dd48c847a \ + --hash=sha256:468d2a840567b13a590e67dd276c570f8de00ed767ecc611994c301d0f8c014f \ + --hash=sha256:4cc152c5dd831641e995764f9f0b6589519f6f5123258ccaca8c6d34572fefa8 \ + --hash=sha256:542da1178c1c6af8873e143910e2269add130a299c9106eef2594e15dae5e482 \ + --hash=sha256:557b21a44ceac6c6b9773bc65aa1b4cc3e248a5ad2f5b914b91579a32e22204d \ + --hash=sha256:5707a746c6083a3a74b46b3a631d78d129edab06195a92a8ece755aac25a3f3d \ + --hash=sha256:588245972aca710b5b68802c8cad9edaa98589b1b42ad2b53accd6910dad3545 \ + --hash=sha256:5adf257bd58c1b8632046bbe43ee38c04e1038e9d37de9c57a94d6bd6ce5da34 \ + --hash=sha256:619d1c96099be5823db34fe89e2582b336b5b074a7f47f819d6b3a57ff7bdb86 \ + --hash=sha256:63563193aec44bce707e0c5ca64ff69fa72ed7cf34ce6e11d5127555756fd2f6 \ + --hash=sha256:67b8cc9574bb518ec76dc8e705d4c39ae78bb96237cb533edac149352c1f39fe \ + 
--hash=sha256:6a685067d05e46641d5d1623d7c7fdf15a357546cbb2f71b0ebde91b175ffc3e \ + --hash=sha256:70f1d09c0d7748b73290b29219e854b3207aea922f839437870d8cc2168e31cc \ + --hash=sha256:750b446b2ffce1739e8578576092179160f6d26bd5e23eb1789c4d64d5af7dc7 \ + --hash=sha256:7966951325782121e67c81299a031f4c115615e68046f79b85856b86ebffc4cd \ + --hash=sha256:7b8b8bf1189b3ba9b8de5c8db4d541b406611a71a955bbbd7385bbc45fcb786c \ + --hash=sha256:7f5d10bae5d78e4551b7be7a9b29643a95aded9d0f602aa2ba584f0388e7a557 \ + --hash=sha256:805dfea4ca10411a5296bcc75638017215a93ffb584c9e344731eef0dcfb026a \ + --hash=sha256:81bf654678e575403736b85ba3a7867e31c2c30a69bc57fe88e3ace52fb17b89 \ + --hash=sha256:82eb849f085624f6a607538ee7b83a6d8126df6d2f7d3b319cb837b289123078 \ + --hash=sha256:85a32721ddde63c9df9ebb0d2045b9691d9750cb139c161c80e500d210f5e26e \ + --hash=sha256:86d1f65ac145e2c9ed71d8ffb1905e9bba3a91ae29ba55b4c46ae6fc31d7c0d4 \ + --hash=sha256:86f63face3a527284f7bb8a9d4f78988e3c06823f7bea2bd6f0e0e9298ca0403 \ + --hash=sha256:8eaf82f0eccd1505cf39a45a6bd0a8cf1c70dcfc30dba338207a969d91b965c0 \ + --hash=sha256:93aa7eef6ee71c629b51ef873991d6911b906d7312c6e8e99790c0f33c576f89 \ + --hash=sha256:96c2b49eb6a72c0e4991d62406e365d87067ca14c1a729a870d22354e6f68115 \ + --hash=sha256:9cf3126b85822c4e53aa28c7ec9869b924d6fcfb76e77a45c44b83d91afd74f9 \ + --hash=sha256:9fe359b2e3a7729010060fbca442ca225280c16e923b37db0e955ac2a2b72a05 \ + --hash=sha256:a0ac5e7015a5920cfce654c06618ec40c33e12801711da6b4258af59a8eff00a \ + --hash=sha256:a3f93dab657839dfa61025056606600a11d0b696d79386f974e459a3fbc568ec \ + --hash=sha256:a4b71f4d1765639372a3b32d2638197f5cd5221b19531f9245fcc9ee62d38f56 \ + --hash=sha256:aae32c93e0f64469f74ccc730a7cb21c7610af3a775157e50bbd38f816536b38 \ + --hash=sha256:aaf7b34c5bc56b38c931a54f7952f1ff0ae77a2e82496583b247f7c969eb1479 \ + --hash=sha256:abecce40dfebbfa6abf8e324e1860092eeca6f7375c8c4e655a8afb61af58f2c \ + --hash=sha256:abf0d9f45ea5fb95051c8bfe43cb40cda383772f7e5023a83cc481ca2604d74e \ + 
--hash=sha256:ac71b2977fb90c35d41c9453116e283fac47bb9096ad917b8819ca8b943abecd \ + --hash=sha256:ada214c6fa40f8d800e575de6b91a40d0548139e5dc457d2ebb61470abf50186 \ + --hash=sha256:b09719a17a2301178fac4470d54b1680b18a5048b481cb8890e1ef820cb80455 \ + --hash=sha256:b1121de0e9d6e6ca08289583d7491e7fcb18a439305b34a30b20d8215922d43c \ + --hash=sha256:b3b2316b25644b23b54a6f6401074cebcecd1244c0b8e80111c9a3f1c8e83d65 \ + --hash=sha256:b3d9b48ee6e3967b7901c052b670c7dda6deb812c309439adaffdec55c6d7b78 \ + --hash=sha256:b5bcf60a228acae568e9911f410f9d9e0d43197d030ae5799e20dca8df588287 \ + --hash=sha256:b8f3307af845803fb0b060ab76cf6dd3a13adc15b6b451f54281d25911eb92df \ + --hash=sha256:c2af80fb58f0f24b3f3adcb9148e6203fa67dd3f61c4af146ecad033024dde43 \ + --hash=sha256:c350354efb159b8767a6244c166f66e67506e06c8924ed74669b2c70bc8735b1 \ + --hash=sha256:c5a74c359b2d47d26cdbbc7845e9662d6b08a1e915eb015d044729e92e7050b7 \ + --hash=sha256:c71f16da1ed8949774ef79f4a0260d28b83b3a50c6576f8f4f0288d109777989 \ + --hash=sha256:d47ecf253780c90ee181d4d871cd655a789da937454045b17b5798da9393901a \ + --hash=sha256:d7eff0f27edc5afa9e405f7165f85a6d782d308f3b6b9d96016c010597958e63 \ + --hash=sha256:d97d85fa63f315a8bdaba2af9a6a686e0eceab77b3089af45133252618e70884 \ + --hash=sha256:db756e48f9c5c607b5e33dd36b1d5872d0422e960145b08ab0ec7fd420e9d649 \ + --hash=sha256:dc45229747b67ffc441b3de2f3ae5e62877a282ea828a5bdb67883c4ee4a8810 \ + --hash=sha256:e0fc42822278451bc13a2e8626cf2218ba570f27856b536e00cfa53099724828 \ + --hash=sha256:e39c7eb31e3f5b1f88caff88bcff1b7f8334975b46f6ac6e9fc725d829bc35d4 \ + --hash=sha256:e46cd37076971c1040fc8c41273a8b3e2c624ce4f2be3f5dfcb7a430c1d3acc2 \ + --hash=sha256:e5c1502d4ace69a179305abb3f0bb6141cbe4714bc9b31d427329a95acfc8bdd \ + --hash=sha256:edfe077ab09442d4ef3c52cb1f9dab89bff02f4524afc0acf2d46be17dc479f5 \ + --hash=sha256:effe5406c9bd748a871dbcaf3ac69167c38d72db8c9baf3ff954c344f31c4cbe \ + --hash=sha256:f0d1e3732768fecb052d90d62b220af62ead5748ac51ef61e7b32c266cac9293 \ + 
--hash=sha256:f5969baeaea61c97efa706b9b107dcba02784b1601c74ac84f2a532ea079403e \ + --hash=sha256:f8888e31e3a85943743f8fc15e71536bda1c81d5aa36d014a3c0c44481d7db6e \ + --hash=sha256:fc52b79d83a3fe3a360902d3f5d79073a993597d48114c29485e9431092905d8 + # via requests +chex==0.1.7 \ + --hash=sha256:74ed49799ac4d229881456d468136f1b19a9f9839e3de72b058824e2a4f4dedd \ + --hash=sha256:9f583015303b1205443843c0b55849bb287f1dfdbd22d9907b1ebb04f964d93e + # via optax +click==8.1.7 \ + --hash=sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28 \ + --hash=sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de + # via + # nltk + # tensorflow-datasets + # tfds-nightly +clu==0.0.7 \ + --hash=sha256:028cb27b7c1b32a59e969477d3ffe9d6b3e80d4c860ac1195cc9e46eacb0605f \ + --hash=sha256:449a9af179c3a5c44fc18947fb844a00ea2cd31e3ede13167ff6efe89367db2b + # via + # -r requirements.in + # seqio + # seqio-nightly +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 + # via + # rich + # sacrebleu +commonmark==0.9.1 \ + --hash=sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60 \ + --hash=sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9 + # via rich +contextlib2==21.6.0 \ + --hash=sha256:3fbdb64466afd23abaf6c977627b75b6139a5a3e8ce38405c5b413aed7a0471f \ + --hash=sha256:ab1e2bfe1d01d968e1b7e8d9023bc51ef3509bba217bb730cee3827e1ee82869 + # via ml-collections +contourpy==1.1.1 \ + --hash=sha256:059c3d2a94b930f4dafe8105bcdc1b21de99b30b51b5bce74c753686de858cb6 \ + --hash=sha256:0683e1ae20dc038075d92e0e0148f09ffcefab120e57f6b4c9c0f477ec171f33 \ + --hash=sha256:07d6f11dfaf80a84c97f1a5ba50d129d9303c5b4206f776e94037332e298dda8 \ + --hash=sha256:081f3c0880712e40effc5f4c3b08feca6d064cb8cfbb372ca548105b86fd6c3d \ + --hash=sha256:0e48694d6a9c5a26ee85b10130c77a011a4fedf50a7279fa0bdaf44bafb4299d \ + 
--hash=sha256:11b836b7dbfb74e049c302bbf74b4b8f6cb9d0b6ca1bf86cfa8ba144aedadd9c \ + --hash=sha256:19557fa407e70f20bfaba7d55b4d97b14f9480856c4fb65812e8a05fe1c6f9bf \ + --hash=sha256:229a25f68046c5cf8067d6d6351c8b99e40da11b04d8416bf8d2b1d75922521e \ + --hash=sha256:24216552104ae8f3b34120ef84825400b16eb6133af2e27a190fdc13529f023e \ + --hash=sha256:3b53d5769aa1f2d4ea407c65f2d1d08002952fac1d9e9d307aa2e1023554a163 \ + --hash=sha256:3de23ca4f381c3770dee6d10ead6fff524d540c0f662e763ad1530bde5112532 \ + --hash=sha256:407d864db716a067cc696d61fa1ef6637fedf03606e8417fe2aeed20a061e6b2 \ + --hash=sha256:41339b24471c58dc1499e56783fedc1afa4bb018bcd035cfb0ee2ad2a7501ef8 \ + --hash=sha256:462c59914dc6d81e0b11f37e560b8a7c2dbab6aca4f38be31519d442d6cde1a1 \ + --hash=sha256:46e24f5412c948d81736509377e255f6040e94216bf1a9b5ea1eaa9d29f6ec1b \ + --hash=sha256:498e53573e8b94b1caeb9e62d7c2d053c263ebb6aa259c81050766beb50ff8d9 \ + --hash=sha256:4ebf42695f75ee1a952f98ce9775c873e4971732a87334b099dde90b6af6a916 \ + --hash=sha256:4f9147051cb8fdb29a51dc2482d792b3b23e50f8f57e3720ca2e3d438b7adf23 \ + --hash=sha256:549174b0713d49871c6dee90a4b499d3f12f5e5f69641cd23c50a4542e2ca1eb \ + --hash=sha256:560f1d68a33e89c62da5da4077ba98137a5e4d3a271b29f2f195d0fba2adcb6a \ + --hash=sha256:566f0e41df06dfef2431defcfaa155f0acfa1ca4acbf8fd80895b1e7e2ada40e \ + --hash=sha256:56de98a2fb23025882a18b60c7f0ea2d2d70bbbcfcf878f9067234b1c4818442 \ + --hash=sha256:66544f853bfa85c0d07a68f6c648b2ec81dafd30f272565c37ab47a33b220684 \ + --hash=sha256:6c06e4c6e234fcc65435223c7b2a90f286b7f1b2733058bdf1345d218cc59e34 \ + --hash=sha256:6d0a8efc258659edc5299f9ef32d8d81de8b53b45d67bf4bfa3067f31366764d \ + --hash=sha256:70e5a10f8093d228bb2b552beeb318b8928b8a94763ef03b858ef3612b29395d \ + --hash=sha256:8394e652925a18ef0091115e3cc191fef350ab6dc3cc417f06da66bf98071ae9 \ + --hash=sha256:8636cd2fc5da0fb102a2504fa2c4bea3cbc149533b345d72cdf0e7a924decc45 \ + --hash=sha256:93df44ab351119d14cd1e6b52a5063d3336f0754b72736cc63db59307dabb718 \ + 
--hash=sha256:96ba37c2e24b7212a77da85004c38e7c4d155d3e72a45eeaf22c1f03f607e8ab \ + --hash=sha256:a10dab5ea1bd4401c9483450b5b0ba5416be799bbd50fc7a6cc5e2a15e03e8a3 \ + --hash=sha256:a66045af6cf00e19d02191ab578a50cb93b2028c3eefed999793698e9ea768ae \ + --hash=sha256:a75cc163a5f4531a256f2c523bd80db509a49fc23721b36dd1ef2f60ff41c3cb \ + --hash=sha256:b04c2f0adaf255bf756cf08ebef1be132d3c7a06fe6f9877d55640c5e60c72c5 \ + --hash=sha256:ba42e3810999a0ddd0439e6e5dbf6d034055cdc72b7c5c839f37a7c274cb4eba \ + --hash=sha256:bfc8a5e9238232a45ebc5cb3bfee71f1167064c8d382cadd6076f0d51cff1da0 \ + --hash=sha256:c5bd5680f844c3ff0008523a71949a3ff5e4953eb7701b28760805bc9bcff217 \ + --hash=sha256:c84fdf3da00c2827d634de4fcf17e3e067490c4aea82833625c4c8e6cdea0887 \ + --hash=sha256:ca6fab080484e419528e98624fb5c4282148b847e3602dc8dbe0cb0669469887 \ + --hash=sha256:d0c188ae66b772d9d61d43c6030500344c13e3f73a00d1dc241da896f379bb62 \ + --hash=sha256:d6ab42f223e58b7dac1bb0af32194a7b9311065583cc75ff59dcf301afd8a431 \ + --hash=sha256:dfe80c017973e6a4c367e037cb31601044dd55e6bfacd57370674867d15a899b \ + --hash=sha256:e0c02b75acfea5cab07585d25069207e478d12309557f90a61b5a3b4f77f46ce \ + --hash=sha256:e30aaf2b8a2bac57eb7e1650df1b3a4130e8d0c66fc2f861039d507a11760e1b \ + --hash=sha256:eafbef886566dc1047d7b3d4b14db0d5b7deb99638d8e1be4e23a7c7ac59ff0f \ + --hash=sha256:efe0fab26d598e1ec07d72cf03eaeeba8e42b4ecf6b9ccb5a356fde60ff08b85 \ + --hash=sha256:f08e469821a5e4751c97fcd34bcb586bc243c39c2e39321822060ba902eac49e \ + --hash=sha256:f1eaac5257a8f8a047248d60e8f9315c6cff58f7803971170d952555ef6344a7 \ + --hash=sha256:f29fb0b3f1217dfe9362ec55440d0743fe868497359f2cf93293f4b2701b8251 \ + --hash=sha256:f44d78b61740e4e8c71db1cf1fd56d9050a4747681c59ec1094750a658ceb970 \ + --hash=sha256:f6aec19457617ef468ff091669cca01fa7ea557b12b59a7908b9474bb9674cf0 \ + --hash=sha256:f9dc7f933975367251c1b34da882c4f0e0b2e24bb35dc906d2f598a40b72bfc7 + # via matplotlib +cycler==0.12.1 \ + 
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c + # via matplotlib +dm-tree==0.1.8 \ + --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \ + --hash=sha256:0d3172394079a86c3a759179c65f64c48d1a42b89495fcf38976d11cc3bb952c \ + --hash=sha256:0e9620ccf06393eb6b613b5e366469304622d4ea96ae6540b28a33840e6c89cf \ + --hash=sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430 \ + --hash=sha256:1607ce49aa42f010d1e5e616d92ce899d66835d4d8bea49679582435285515de \ + --hash=sha256:181c35521d480d0365f39300542cb6cd7fd2b77351bb43d7acfda15aef63b317 \ + --hash=sha256:1d7c26e431fc93cc7e0cba867eb000db6a05f6f2b25af11ac4e9dada88fc5bca \ + --hash=sha256:1fe962015b2fe1282892b28ebe962faed53c7f98d942da9a4625cbf27baef913 \ + --hash=sha256:250b692fb75f45f02e2f58fbef9ab338904ef334b90557565621fa251df267cf \ + --hash=sha256:2869228d9c619074de501a3c10dc7f07c75422f8fab36ecdcb859b6f1b1ec3ef \ + --hash=sha256:28c52cbf4f8b3dbd0beaedf44f69fa85eec5e9dede612e08035e06ada6ec9426 \ + --hash=sha256:2f7915660f59c09068e428613c480150180df1060561fd0d1470684ae7007bd1 \ + --hash=sha256:343a4a4ebaa127451ff971254a4be4084eb4bdc0b2513c32b46f6f728fd03f9e \ + --hash=sha256:35cc164a79336bfcfafb47e5f297898359123bbd3330c1967f0c4994f9cf9f60 \ + --hash=sha256:378cc8ad93c5fe3590f405a309980721f021c790ca1bdf9b15bb1d59daec57f5 \ + --hash=sha256:39070ba268c0491af9fe7a58644d99e8b4f2cde6e5884ba3380bddc84ed43d5f \ + --hash=sha256:5483dca4d7eb1a0d65fe86d3b6a53ae717face83c1f17e0887b1a4a64ae5c410 \ + --hash=sha256:694c3654cfd2a81552c08ec66bb5c4a3d48fa292b9a181880fb081c36c5b9134 \ + --hash=sha256:803bfc53b4659f447ac694dbd04235f94a73ef7c1fd1e0df7c84ac41e0bc963b \ + --hash=sha256:81fce77f22a302d7a5968aebdf4efafef4def7ce96528719a354e6990dcd49c7 \ + --hash=sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393 \ + 
--hash=sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571 \ + --hash=sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368 \ + --hash=sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80 \ + --hash=sha256:ad16ceba90a56ec47cf45b21856d14962ac314787975ef786efb5e6e9ca75ec7 \ + --hash=sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d \ + --hash=sha256:b095ba4f8ca1ba19350fd53cf1f8f3eb0bd406aa28af64a6dfc86707b32a810a \ + --hash=sha256:b9bd9b9ccb59409d33d51d84b7668010c04c2af7d4a371632874c1ca356cff3d \ + --hash=sha256:b9f89a454e98806b44fe9d40ec9eee61f848388f7e79ac2371a55679bd5a3ac6 \ + --hash=sha256:bb2d109f42190225112da899b9f3d46d0d5f26aef501c61e43529fe9322530b5 \ + --hash=sha256:c5c8c12e3fda754ef6af94161bacdaeda816d941995fac415d6855c6c386af68 \ + --hash=sha256:d1612fcaecd79023dbc6a6ae48d51a80beb5c385d6f3f6d71688e57bc8d07de8 \ + --hash=sha256:d16e1f2a073604cfcc09f7131ae8d534674f43c3aef4c25742eae295bc60d04f \ + --hash=sha256:d20f2faa3672b52e5013f4077117bfb99c4cfc0b445d3bde1584c34032b57436 \ + --hash=sha256:d40fa4106ca6edc66760246a08f500ec0c85ef55c762fb4a363f6ee739ba02ee \ + --hash=sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb \ + --hash=sha256:e4d714371bb08839e4e5e29024fc95832d9affe129825ef38836b143028bd144 \ + --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ + --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d + # via + # chex + # tensorflow-datasets + # tfds-nightly +docstring-parser==0.15 \ + --hash=sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682 \ + --hash=sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9 + # via pyglove +editdistance==0.6.2 \ + --hash=sha256:01d63d31677c5f009eaed58b4a8b073613ced4c7aaca5221ab6269e735c95aec \ + --hash=sha256:01f1187385d7517ceaea0fdec1c57f78135a2507676f13ca8ab7ca3020cc64c6 \ + 
--hash=sha256:0395141203d51b118de3a6e90753278d5bf4ad66966950be6364790b55164828 \ + --hash=sha256:04681fa6d514620daadb8f380440cac509ff5316fc9867b30537a01396b2b8d7 \ + --hash=sha256:06f1db8dc8f09587e0beb5521d820dd002c57ced8f129528d72d6b8c0be85361 \ + --hash=sha256:0aa2adb28323b8c619c1b9c6bbd1ec8f150ef1fe7b06251107b6920d607b5f17 \ + --hash=sha256:0c1a76b06dd8ee37a83a96d5420f0cc35390a3b80a477e9f1dac898a520f9eca \ + --hash=sha256:11c5fc9c0e2b3563f8f58a97f4a8acbeb0e11fa1c1ae58161f19ff5b223911eb \ + --hash=sha256:1ac8fbdb787edd9aa65749be9abfbaff98d1c61da0de99c6225fa0bb3fe3a527 \ + --hash=sha256:1b6ab85adae46c8d191132cf67815434431be35b479b1a5844dce5e8d4d11b9d \ + --hash=sha256:1f3ef008568d50c24cf60c944120a26f5d19ebce65f030059288819738717b06 \ + --hash=sha256:22ab574cbec36552e8b0ae673feca08d9617bf94b837f9534fc2b82b9da38546 \ + --hash=sha256:2c48541b77a24acf548b5bf83189988bf4c93bc324503a6ea630453549017b46 \ + --hash=sha256:327efa37b801e45b3a12937230965c14ecefc79ab71b495cf38eb39ef70bb88c \ + --hash=sha256:33cf7b216b23ad6b949b531fee8e934474fa02ffaf1c55cc556931214ff37ef4 \ + --hash=sha256:3435165327074a4321484e2de06e4b74ef03c34e103c39eee126f7277bb8123f \ + --hash=sha256:347bce77dac7e15bd6dbbfb5f09da0a5d5e64f6df3d5409abb3371b8005652d7 \ + --hash=sha256:3842411d9126249db7b52c61ce22571dc91292b252c7a7a71ebf34805fc8710f \ + --hash=sha256:3ccab6307fa6299da911a6c14c1914086754dde9f32ec7979a98656e6e5e18b7 \ + --hash=sha256:3d19ecb414072a44c7d10a239c8494aea0faa476bc541766b7a3228fc9c19943 \ + --hash=sha256:454597f8f4bc0d680023753cc28abb0bf288527eddfb6cad43666efe388c91a2 \ + --hash=sha256:47a58a2a0e4bd63d33c1541f3f6c2f6b4c9699c2684ef7beeb20d93768d89a8b \ + --hash=sha256:50234341a9264a91f72136245404cb319b92010b0bd951decdcdb8e809942fca \ + --hash=sha256:51fe3eb5fcb14071735eadf8daf8583126967aec82965fe9d4f928f2b909c310 \ + --hash=sha256:53b1dee4391d83f1fd564d92373c075b473d351b25dc3ccb2b225331cfa7cd57 \ + --hash=sha256:57f3b6ecf872be8b678128f283eb0bb47220c422105e48d0351408fee9cc617a \ + 
--hash=sha256:63b3400b39b058bc2f58f1e4b373276d9795adc397e502efc74b3f24543c936a \ + --hash=sha256:63c73f8c38579c5f6464e129c791a5bfa88b0f2c1f8ffd15620125cfe580ecae \ + --hash=sha256:6537e4f86dc8437abff129d32795a813aa7c08758d03196fde8e5701de497b97 \ + --hash=sha256:65ce074c366e483b190832be4c1194d8f667b0d2a9747241fee476a06d01c0c5 \ + --hash=sha256:66eea6ad1a600e620a17d5b46d29796a79287cfa98db39eaf0aaccf79f6f7552 \ + --hash=sha256:69662c3ca546a6df37bb775b53aad9a6426b224a3e7e1a3062c9f01e88e94d0f \ + --hash=sha256:6adf991c0478de93a7b77dfc11317399248203ba0ff5a7f775e340a03869bf21 \ + --hash=sha256:6b282014742809684c720b9d5d2afb3cc542699e4515bf59a94512f656a36886 \ + --hash=sha256:71aac2d0b4b421aac798bfbcce65bb2aafa0087ab3de4ef986218968f85bd531 \ + --hash=sha256:727dc0c4bb282e7c9efa4dbd385d78f8a7fd7edad068a143a57a72d931f6e45e \ + --hash=sha256:793b55100ec226ed41eeb8ed7dd66f4b30063b601070ff64405351c6871bf949 \ + --hash=sha256:7a7a0b53640447746d9842fbe6b2766c76f92f9c9b554f45b9e665247b1de2f6 \ + --hash=sha256:7b5ddba4883c22a2ad1137cab6d009772cfce1fa843f4b992e1f6834ad10707c \ + --hash=sha256:7fe35c130d8f0e740ec70b6672ac3cfc289625f0ab4c4e9a264583cb5fcc8d83 \ + --hash=sha256:84c50dda926486b5be08aab6afbdc9e7d440c52f7f55e5b0b54efcb7c0742001 \ + --hash=sha256:85d8e5a67a9b5d4d0bf6bcc3e8c1fdbc11f6357def8c2df5a4cea5e19874c79a \ + --hash=sha256:8c269be2a1325759ebc4f6698501349cf1248d7d09f5b239d2780b9085c24c98 \ + --hash=sha256:8da55507e34468646e05bb96fc046c46c882649097bbbe45d4ef785a0ac3f3a4 \ + --hash=sha256:8fdf4834d0b3ec287f92b124889f755a6d28ce97e407147f015af3237290bb9a \ + --hash=sha256:9148a887b2a802c82712671b90618892a657ee8303853e4733ccc59e0ecc7c65 \ + --hash=sha256:931fb160fd3e741d880f1bef46369993b50a45449e5b9c8f2e753df446c6bba5 \ + --hash=sha256:979c1798b0b3ea4a26e570d55f6a5014d5e6c8b3129e927a22047ea54621d188 \ + --hash=sha256:97a722f5e859ed4c26da269e71a11995f23ac9c880618b8a2028373eb74283be \ + --hash=sha256:97fdc521d04b72e7f0bb393283091eaac1def3eaf12295aa4c7627d2beb99ed5 \ + 
--hash=sha256:996c1d43361c661391f18b1920e28a53a2702de7cd9baedc0d61f91ad00e9403 \ + --hash=sha256:9d76a7c0f4f4d81a964118e5ea596948dcfaa85e1d81ecdf8ceac190d669a1d7 \ + --hash=sha256:a69254caa2874087cfda056e25691e215fbb92895971ab8b19db5b0b105f47ea \ + --hash=sha256:abe4da756fadc5c84745b6b07794d504968aebb280931909e3b5e92ea434fbe0 \ + --hash=sha256:aed1aa6b3ba6a17d623ac5115412023b3c52c85727a97cdae14bd83489d1c173 \ + --hash=sha256:b26bdd02a7243cca646e69ca1ab1e28b229fd838271fb1413d55afc89f141033 \ + --hash=sha256:b4fed965589ffd1f382ba26d06811427c57621a93cae3e31a986bc4c8fa7a716 \ + --hash=sha256:b7c1f7c81ee85774f738bd1fbdeeeae3b7853301271c93308f7f26fa830600e3 \ + --hash=sha256:bce6b5191a1dac919d06357179654b24304f40a22b2cfc5d631210cb05964d34 \ + --hash=sha256:c5c4b3327ba03841700b41e1393580732b7a15d553c47890aef4651fd42000cc \ + --hash=sha256:c91cc8f69353811732f50465251ecccb87b2d7a5127ee739370f797a1e83c927 \ + --hash=sha256:cad240c00f0e6a3a6a23ac962585e14fcd8fa8ec041b5a57937acefa8beed658 \ + --hash=sha256:d5cb6f5464823317b8884993e9b60fc590af950bc90f318913de10b85a9a2a44 \ + --hash=sha256:d917fa9d82f51b57a8d23abbc03f49640488cbfe4b9d6351ba3b562cadd95f6b \ + --hash=sha256:d9971beaf82c1ea6a5802daaac069a92b43309bbbc016fdadc9a54b1c4ce840f \ + --hash=sha256:dc63126b80320abbd6070ba9f82fed0a7d4388095324982ba5d7112a8c783abc \ + --hash=sha256:ddc02826be4a17fca0833ecc08a4ff1e6ec037a8a4b18567b2628c8ff309e8df \ + --hash=sha256:e17d5e2a14fd2fe851d370576afb975dbc8a098c4ff6a858462c067ebd015849 \ + --hash=sha256:e375ed11f06f47f29e6fcb3e23fd4d1abe8310f1b666eb795b7bc2f343d36000 \ + --hash=sha256:e609ef6ca83ade4daf4a7a4dbc9154a0a5000b33796c04ff295d1a302f43af19 \ + --hash=sha256:e9a39d7567fa71f6304a1c97dfb8b34773c7df78968db8a26a723043b9c96251 \ + --hash=sha256:eab0fc3a1c3b4b3518f64aa328f9ef5c382413343c99bf566a4cd5b45a4cf97d \ + --hash=sha256:ede7a5a67f35cc0f7eb4b7230cbe99e80d54b127794fff1415b2496084ac0117 \ + --hash=sha256:ef96fb4e43362cdf723d243045fcbb87c0d1a43fe0797b168b8a659bf8e6a3c1 \ + 
--hash=sha256:f2f5569e7a870d7dfc00c301688d03cbe031bf4a83f390886d410c82315377be \ + --hash=sha256:f32890793c2de47968caff55af277d3e4f2f536f38afd5755b2508f0f4f338a9 \ + --hash=sha256:f5b4095f5db7cdcc67b6df7a87a40c120b67e7ce63fa38d1900fc76271b6405b \ + --hash=sha256:f8d72d4283ff14cd4c9d3d38030f159f71fea2ceb611fcdbad1022e912262dfa + # via + # seqio + # seqio-nightly + # t5 +etils[array-types,enp,epath,epy,etqdm,etree]==1.5.1 \ + --hash=sha256:2c1bfa2817eb4881cb509097f1e65ac6160126ba74ec47b3bb47ee678628d8c8 \ + --hash=sha256:b530c0d1b2ed1b8da1af367d4b97891e680a6a4658d4190183210bbbb8cf1fb9 + # via + # array-record + # clu + # tensorflow-datasets + # tfds-nightly +filelock==3.12.4 \ + --hash=sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4 \ + --hash=sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd + # via + # gdown + # huggingface-hub + # transformers +flatbuffers==23.5.26 \ + --hash=sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89 \ + --hash=sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1 + # via tensorflow +flax==0.5.3 \ + --hash=sha256:7ab7f75abc9dd7568ee86279c2e167bbd7ecbeeeac03589d93119b3321036cc0 \ + --hash=sha256:911dc4d6463ccba9808d303f44dfb4a0a6de277be8f3cbfda1471ed1a3e19374 + # via + # -r requirements.in + # clu +fonttools==4.43.1 \ + --hash=sha256:10003ebd81fec0192c889e63a9c8c63f88c7d72ae0460b7ba0cd2a1db246e5ad \ + --hash=sha256:10b3922875ffcba636674f406f9ab9a559564fdbaa253d66222019d569db869c \ + --hash=sha256:13a9a185259ed144def3682f74fdcf6596f2294e56fe62dfd2be736674500dba \ + --hash=sha256:17dbc2eeafb38d5d0e865dcce16e313c58265a6d2d20081c435f84dc5a9d8212 \ + --hash=sha256:18a2477c62a728f4d6e88c45ee9ee0229405e7267d7d79ce1f5ce0f3e9f8ab86 \ + --hash=sha256:18eefac1b247049a3a44bcd6e8c8fd8b97f3cad6f728173b5d81dced12d6c477 \ + --hash=sha256:1952c89a45caceedf2ab2506d9a95756e12b235c7182a7a0fff4f5e52227204f \ + 
--hash=sha256:1cf9e974f63b1080b1d2686180fc1fbfd3bfcfa3e1128695b5de337eb9075cef \ + --hash=sha256:1e09da7e8519e336239fbd375156488a4c4945f11c4c5792ee086dd84f784d02 \ + --hash=sha256:2062542a7565091cea4cc14dd99feff473268b5b8afdee564f7067dd9fff5860 \ + --hash=sha256:25d3da8a01442cbc1106490eddb6d31d7dffb38c1edbfabbcc8db371b3386d72 \ + --hash=sha256:34f713dad41aa21c637b4e04fe507c36b986a40f7179dcc86402237e2d39dcd3 \ + --hash=sha256:360201d46165fc0753229afe785900bc9596ee6974833124f4e5e9f98d0f592b \ + --hash=sha256:3b7ad05b2beeebafb86aa01982e9768d61c2232f16470f9d0d8e385798e37184 \ + --hash=sha256:4c54466f642d2116686268c3e5f35ebb10e49b0d48d41a847f0e171c785f7ac7 \ + --hash=sha256:4d9740e3783c748521e77d3c397dc0662062c88fd93600a3c2087d3d627cd5e5 \ + --hash=sha256:4f88cae635bfe4bbbdc29d479a297bb525a94889184bb69fa9560c2d4834ddb9 \ + --hash=sha256:51669b60ee2a4ad6c7fc17539a43ffffc8ef69fd5dbed186a38a79c0ac1f5db7 \ + --hash=sha256:5db46659cfe4e321158de74c6f71617e65dc92e54980086823a207f1c1c0e24b \ + --hash=sha256:5f37e31291bf99a63328668bb83b0669f2688f329c4c0d80643acee6e63cd933 \ + --hash=sha256:6bb5ea9076e0e39defa2c325fc086593ae582088e91c0746bee7a5a197be3da0 \ + --hash=sha256:748015d6f28f704e7d95cd3c808b483c5fb87fd3eefe172a9da54746ad56bfb6 \ + --hash=sha256:7bbbf8174501285049e64d174e29f9578495e1b3b16c07c31910d55ad57683d8 \ + --hash=sha256:884ef38a5a2fd47b0c1291647b15f4e88b9de5338ffa24ee52c77d52b4dfd09c \ + --hash=sha256:8da417431bfc9885a505e86ba706f03f598c85f5a9c54f67d63e84b9948ce590 \ + --hash=sha256:95e974d70238fc2be5f444fa91f6347191d0e914d5d8ae002c9aa189572cc215 \ + --hash=sha256:9648518ef687ba818db3fcc5d9aae27a369253ac09a81ed25c3867e8657a0680 \ + --hash=sha256:9a2f0aa6ca7c9bc1058a9d0b35483d4216e0c1bbe3962bc62ce112749954c7b8 \ + --hash=sha256:9c36da88422e0270fbc7fd959dc9749d31a958506c1d000e16703c2fce43e3d0 \ + --hash=sha256:9c60ecfa62839f7184f741d0509b5c039d391c3aff71dc5bc57b87cc305cff3b \ + --hash=sha256:9f727c3e3d08fd25352ed76cc3cb61486f8ed3f46109edf39e5a60fc9fecf6ca \ + 
--hash=sha256:a7a06f8d95b7496e53af80d974d63516ffb263a468e614978f3899a6df52d4b3 \ + --hash=sha256:ad0b3f6342cfa14be996971ea2b28b125ad681c6277c4cd0fbdb50340220dfb6 \ + --hash=sha256:b2adca1b46d69dce4a37eecc096fe01a65d81a2f5c13b25ad54d5430ae430b13 \ + --hash=sha256:b84a1c00f832feb9d0585ca8432fba104c819e42ff685fcce83537e2e7e91204 \ + --hash=sha256:bb6d2f8ef81ea076877d76acfb6f9534a9c5f31dc94ba70ad001267ac3a8e56f \ + --hash=sha256:bf11e2cca121df35e295bd34b309046c29476ee739753bc6bc9d5050de319273 \ + --hash=sha256:d21099b411e2006d3c3e1f9aaf339e12037dbf7bf9337faf0e93ec915991f43b \ + --hash=sha256:d4071bd1c183b8d0b368cc9ed3c07a0f6eb1bdfc4941c4c024c49a35429ac7cd \ + --hash=sha256:e117a92b07407a061cde48158c03587ab97e74e7d73cb65e6aadb17af191162a \ + --hash=sha256:f7a58eb5e736d7cf198eee94844b81c9573102ae5989ebcaa1d1a37acd04b33d \ + --hash=sha256:fe9b1ec799b6086460a7480e0f55c447b1aca0a4eecc53e444f639e967348896 + # via matplotlib +fsspec==2023.9.2 \ + --hash=sha256:603dbc52c75b84da501b9b2ec8c11e1f61c25984c4a0dda1f129ef391fbfc9b4 \ + --hash=sha256:80bfb8c70cc27b2178cc62a935ecf242fc6e8c3fb801f9c571fc01b1e715ba7d + # via + # etils + # huggingface-hub +future==0.18.3 \ + --hash=sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307 + # via mesh-tensorflow +gast==0.4.0 \ + --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ + --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 + # via tensorflow +gdown==4.7.1 \ + --hash=sha256:347f23769679aaf7efa73e5655270fcda8ca56be65eb84a4a21d143989541045 \ + --hash=sha256:65d495699e7c2c61af0d0e9c32748fb4f79abaf80d747a87456c7be14aac2560 + # via -r requirements.in +gin==0.1.6 \ + --hash=sha256:0747da840881792f1726f9145094953b0a1499e9b41324a14ca6a10c03baa1ef + # via -r requirements.in +gin-config==0.5.0 \ + --hash=sha256:0c6ea5026ded927c8c93c990b01c695257c1df446e45e549a158cfbc79e19ed6 \ + --hash=sha256:bddb7ca221ea2b46cdb59321e79fecf02d6e3b728906047fcd4076c297609fd6 + # via + # -r 
requirements.in + # mesh-tensorflow + # t5 +google-auth==2.23.3 \ + --hash=sha256:6864247895eea5d13b9c57c9e03abb49cb94ce2dc7c58e91cba3248c7477c9e3 \ + --hash=sha256:a8f4608e65c244ead9e0538f181a96c6e11199ec114d41f1d7b1bffa96937bda + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==1.0.0 \ + --hash=sha256:95880ca704928c300f48194d1770cf5b1462835b6e49db61445a520f793fd5fb \ + --hash=sha256:e375064964820b47221a7e1b7ee1fd77051b6323c3f9e3e19785f78ab67ecfc5 + # via tensorboard +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via tensorflow +googleapis-common-protos==1.61.0 \ + --hash=sha256:22f1915393bb3245343f6efe87f6fe868532efc12aa26b391b15132e1279f1c0 \ + --hash=sha256:8a64866a97f6304a7179873a465d6eee97b7a24ec6cfd78e0f575e96b821240b + # via tensorflow-metadata +grpcio==1.59.0 \ + --hash=sha256:0ae444221b2c16d8211b55326f8ba173ba8f8c76349bfc1768198ba592b58f74 \ + --hash=sha256:0b84445fa94d59e6806c10266b977f92fa997db3585f125d6b751af02ff8b9fe \ + --hash=sha256:14890da86a0c0e9dc1ea8e90101d7a3e0e7b1e71f4487fab36e2bfd2ecadd13c \ + --hash=sha256:15f03bd714f987d48ae57fe092cf81960ae36da4e520e729392a59a75cda4f29 \ + --hash=sha256:1a839ba86764cc48226f50b924216000c79779c563a301586a107bda9cbe9dcf \ + --hash=sha256:225e5fa61c35eeaebb4e7491cd2d768cd8eb6ed00f2664fa83a58f29418b39fd \ + --hash=sha256:228b91ce454876d7eed74041aff24a8f04c0306b7250a2da99d35dd25e2a1211 \ + --hash=sha256:2ea95cd6abbe20138b8df965b4a8674ec312aaef3147c0f46a0bac661f09e8d0 \ + --hash=sha256:2f120d27051e4c59db2f267b71b833796770d3ea36ca712befa8c5fff5da6ebd \ + --hash=sha256:34341d9e81a4b669a5f5dca3b2a760b6798e95cdda2b173e65d29d0b16692857 \ + --hash=sha256:3859917de234a0a2a52132489c4425a73669de9c458b01c9a83687f1f31b5b10 \ + 
--hash=sha256:38823bd088c69f59966f594d087d3a929d1ef310506bee9e3648317660d65b81 \ + --hash=sha256:38da5310ef84e16d638ad89550b5b9424df508fd5c7b968b90eb9629ca9be4b9 \ + --hash=sha256:3b8ff795d35a93d1df6531f31c1502673d1cebeeba93d0f9bd74617381507e3f \ + --hash=sha256:50eff97397e29eeee5df106ea1afce3ee134d567aa2c8e04fabab05c79d791a7 \ + --hash=sha256:5711c51e204dc52065f4a3327dca46e69636a0b76d3e98c2c28c4ccef9b04c52 \ + --hash=sha256:598f3530231cf10ae03f4ab92d48c3be1fee0c52213a1d5958df1a90957e6a88 \ + --hash=sha256:611d9aa0017fa386809bddcb76653a5ab18c264faf4d9ff35cb904d44745f575 \ + --hash=sha256:61bc72a00ecc2b79d9695220b4d02e8ba53b702b42411397e831c9b0589f08a3 \ + --hash=sha256:63982150a7d598281fa1d7ffead6096e543ff8be189d3235dd2b5604f2c553e5 \ + --hash=sha256:6c4b1cc3a9dc1924d2eb26eec8792fedd4b3fcd10111e26c1d551f2e4eda79ce \ + --hash=sha256:81d86a096ccd24a57fa5772a544c9e566218bc4de49e8c909882dae9d73392df \ + --hash=sha256:849c47ef42424c86af069a9c5e691a765e304079755d5c29eff511263fad9c2a \ + --hash=sha256:871371ce0c0055d3db2a86fdebd1e1d647cf21a8912acc30052660297a5a6901 \ + --hash=sha256:8cd2d38c2d52f607d75a74143113174c36d8a416d9472415eab834f837580cf7 \ + --hash=sha256:936b2e04663660c600d5173bc2cc84e15adbad9c8f71946eb833b0afc205b996 \ + --hash=sha256:93e9cb546e610829e462147ce724a9cb108e61647a3454500438a6deef610be1 \ + --hash=sha256:956f0b7cb465a65de1bd90d5a7475b4dc55089b25042fe0f6c870707e9aabb1d \ + --hash=sha256:986de4aa75646e963466b386a8c5055c8b23a26a36a6c99052385d6fe8aaf180 \ + --hash=sha256:aca8a24fef80bef73f83eb8153f5f5a0134d9539b4c436a716256b311dda90a6 \ + --hash=sha256:acf70a63cf09dd494000007b798aff88a436e1c03b394995ce450be437b8e54f \ + --hash=sha256:b34c7a4c31841a2ea27246a05eed8a80c319bfc0d3e644412ec9ce437105ff6c \ + --hash=sha256:b95ec8ecc4f703f5caaa8d96e93e40c7f589bad299a2617bdb8becbcce525539 \ + --hash=sha256:ba0ca727a173ee093f49ead932c051af463258b4b493b956a2c099696f38aa66 \ + --hash=sha256:c041a91712bf23b2a910f61e16565a05869e505dc5a5c025d429ca6de5de842c \ + 
--hash=sha256:c0488c2b0528e6072010182075615620071371701733c63ab5be49140ed8f7f0 \ + --hash=sha256:c173a87d622ea074ce79be33b952f0b424fa92182063c3bda8625c11d3585d09 \ + --hash=sha256:c251d22de8f9f5cca9ee47e4bade7c5c853e6e40743f47f5cc02288ee7a87252 \ + --hash=sha256:c4dfdb49f4997dc664f30116af2d34751b91aa031f8c8ee251ce4dcfc11277b0 \ + --hash=sha256:ca87ee6183421b7cea3544190061f6c1c3dfc959e0b57a5286b108511fd34ff4 \ + --hash=sha256:ceb1e68135788c3fce2211de86a7597591f0b9a0d2bb80e8401fd1d915991bac \ + --hash=sha256:d09bd2a4e9f5a44d36bb8684f284835c14d30c22d8ec92ce796655af12163588 \ + --hash=sha256:d0fcf53df684fcc0154b1e61f6b4a8c4cf5f49d98a63511e3f30966feff39cd0 \ + --hash=sha256:d74f7d2d7c242a6af9d4d069552ec3669965b74fed6b92946e0e13b4168374f9 \ + --hash=sha256:de2599985b7c1b4ce7526e15c969d66b93687571aa008ca749d6235d056b7205 \ + --hash=sha256:e5378785dce2b91eb2e5b857ec7602305a3b5cf78311767146464bfa365fc897 \ + --hash=sha256:ec78aebb9b6771d6a1de7b6ca2f779a2f6113b9108d486e904bde323d51f5589 \ + --hash=sha256:f1feb034321ae2f718172d86b8276c03599846dc7bb1792ae370af02718f91c5 \ + --hash=sha256:f21917aa50b40842b51aff2de6ebf9e2f6af3fe0971c31960ad6a3a2b24988f4 \ + --hash=sha256:f367e4b524cb319e50acbdea57bb63c3b717c5d561974ace0b065a648bb3bad3 \ + --hash=sha256:f6cfe44a5d7c7d5f1017a7da1c8160304091ca5dc64a0f85bca0d63008c3137a \ + --hash=sha256:fa66cac32861500f280bb60fe7d5b3e22d68c51e18e65367e38f8669b78cea3b \ + --hash=sha256:fc8bf2e7bc725e76c0c11e474634a08c8f24bcf7426c0c6d60c8f9c6e70e4d4a \ + --hash=sha256:fe976910de34d21057bcb53b2c5e667843588b48bf11339da2a75f5c4c5b4055 + # via + # tensorboard + # tensorflow +h5py==3.10.0 \ + --hash=sha256:012ab448590e3c4f5a8dd0f3533255bc57f80629bf7c5054cf4c87b30085063c \ + --hash=sha256:212bb997a91e6a895ce5e2f365ba764debeaef5d2dca5c6fb7098d66607adf99 \ + --hash=sha256:2381e98af081b6df7f6db300cd88f88e740649d77736e4b53db522d8874bf2dc \ + --hash=sha256:2c8e4fda19eb769e9a678592e67eaec3a2f069f7570c82d2da909c077aa94339 \ + 
--hash=sha256:3074ec45d3dc6e178c6f96834cf8108bf4a60ccb5ab044e16909580352010a97 \ + --hash=sha256:3c97d03f87f215e7759a354460fb4b0d0f27001450b18b23e556e7856a0b21c3 \ + --hash=sha256:43a61b2c2ad65b1fabc28802d133eed34debcc2c8b420cb213d3d4ef4d3e2229 \ + --hash=sha256:492305a074327e8d2513011fa9fffeb54ecb28a04ca4c4227d7e1e9616d35641 \ + --hash=sha256:5dfc65ac21fa2f630323c92453cadbe8d4f504726ec42f6a56cf80c2f90d6c52 \ + --hash=sha256:667fe23ab33d5a8a6b77970b229e14ae3bb84e4ea3382cc08567a02e1499eedd \ + --hash=sha256:6c013d2e79c00f28ffd0cc24e68665ea03ae9069e167087b2adb5727d2736a52 \ + --hash=sha256:781a24263c1270a62cd67be59f293e62b76acfcc207afa6384961762bb88ea03 \ + --hash=sha256:86df4c2de68257b8539a18646ceccdcf2c1ce6b1768ada16c8dcfb489eafae20 \ + --hash=sha256:90286b79abd085e4e65e07c1bd7ee65a0f15818ea107f44b175d2dfe1a4674b7 \ + --hash=sha256:92273ce69ae4983dadb898fd4d3bea5eb90820df953b401282ee69ad648df684 \ + --hash=sha256:93dd840bd675787fc0b016f7a05fc6efe37312a08849d9dd4053fd0377b1357f \ + --hash=sha256:9450464b458cca2c86252b624279115dcaa7260a40d3cb1594bf2b410a2bd1a3 \ + --hash=sha256:ae2f0201c950059676455daf92700eeb57dcf5caaf71b9e1328e6e6593601770 \ + --hash=sha256:aece0e2e1ed2aab076c41802e50a0c3e5ef8816d60ece39107d68717d4559824 \ + --hash=sha256:b963fb772964fc1d1563c57e4e2e874022ce11f75ddc6df1a626f42bd49ab99f \ + --hash=sha256:ba9ab36be991119a3ff32d0c7cbe5faf9b8d2375b5278b2aea64effbeba66039 \ + --hash=sha256:d4682b94fd36ab217352be438abd44c8f357c5449b8995e63886b431d260f3d3 \ + --hash=sha256:d93adc48ceeb33347eb24a634fb787efc7ae4644e6ea4ba733d099605045c049 \ + --hash=sha256:f42e6c30698b520f0295d70157c4e202a9e402406f50dc08f5a7bc416b24e52d \ + --hash=sha256:fd6f6d1384a9f491732cee233b99cd4bfd6e838a8815cc86722f9d2ee64032af + # via tensorflow +huggingface-hub==0.17.3 \ + --hash=sha256:40439632b211311f788964602bf8b0d9d6b7a2314fba4e8d67b2ce3ecea0e3fd \ + --hash=sha256:545eb3665f6ac587add946e73984148f2ea5c7877eac2e845549730570c1933a + # via + # tokenizers + # transformers +idna==3.4 
\ + --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ + --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 + # via requests +immutabledict==3.0.0 \ + --hash=sha256:034bacc6c6872707c4ec0ea9515de6bbe0dcf0fcabd97ae19fd4e4c338f05798 \ + --hash=sha256:5a23cd369a6187f76a8c29d7d687980b092538eb9800e58964603f1b973c56fe + # via t5 +importlib-resources==6.1.0 \ + --hash=sha256:9d48dcccc213325e810fd723e7fbb45ccb39f6cf5c31f00cf2b965f5f10f3cb9 \ + --hash=sha256:aa50258bbfa56d4e33fbd8aa3ef48ded10d1735f11532b8df95388cc6bdb7e83 + # via etils +jax==0.4.6 \ + --hash=sha256:d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab + # via + # -r requirements.in + # chex + # clu + # flax + # optax + # seqio + # seqio-nightly +jaxlib==0.4.6 \ + --hash=sha256:1393968b0f808c1769195990f1ea138903bc4012bdffb850eecd10e113b8fca8 \ + --hash=sha256:2949b6b6b77f296982b42ad2c6350526baf47b0f105118a65a9e9b2093de6572 \ + --hash=sha256:34b0b4e41e185fba36b81cf68d4979503afba4640bf29b7f6709b17b3c3c55bc \ + --hash=sha256:399d83e16a35d66693b27951b87be566d91e47c7f4ac1fc5a362536a7b9c29cc \ + --hash=sha256:3be2b70104f9547a3281e5c06dfafe1be27c4927d6b62b69a55d26977cd03e15 \ + --hash=sha256:4079c42247db33c69f10710b6a5a570ed89d773e0e549612fadba3d06fe4773c \ + --hash=sha256:554fc7300c61d76b5996145bd33a8dd19d60c49e57ad38686057d719b1d69d38 \ + --hash=sha256:c42ea44671f1e560d63f742c787a65744f1efd0a20bbfe177e9d3e8bd7cece92 \ + --hash=sha256:cdfbff50bae46065d2fac6250260077e2c554df52252c45cb5ca949bed378b6f \ + --hash=sha256:e18e3d5fd5d1aee94bd97791c7157ea7fd682f5eb8a04a8b1a3b0ed011175892 \ + --hash=sha256:ed2f504f0d48a1e727322aa5baae0ec0405fc5fcb5e4135bb15740978535b5e0 \ + --hash=sha256:f7233463e1f79b330d3a1e12629e1bbc334acf6f5be22a0af244a2ed544afdfe + # via + # -r requirements.in + # chex + # clu + # optax + # seqio + # seqio-nightly +joblib==1.3.2 \ + --hash=sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1 \ + 
--hash=sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9 + # via + # nltk + # scikit-learn +keras==2.13.1 \ + --hash=sha256:5ce5f706f779fa7330e63632f327b75ce38144a120376b2ae1917c00fa6136af \ + --hash=sha256:5df12cc241a015a11b65ddb452c0eeb2744fce21d9b54ba48db87492568ccc68 + # via tensorflow +kiwisolver==1.4.5 \ + --hash=sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf \ + --hash=sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e \ + --hash=sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af \ + --hash=sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f \ + --hash=sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046 \ + --hash=sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3 \ + --hash=sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5 \ + --hash=sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71 \ + --hash=sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee \ + --hash=sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3 \ + --hash=sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9 \ + --hash=sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b \ + --hash=sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985 \ + --hash=sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea \ + --hash=sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16 \ + --hash=sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89 \ + --hash=sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c \ + --hash=sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9 \ + --hash=sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712 \ + --hash=sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342 \ + 
--hash=sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a \ + --hash=sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958 \ + --hash=sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d \ + --hash=sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a \ + --hash=sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130 \ + --hash=sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff \ + --hash=sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898 \ + --hash=sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b \ + --hash=sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f \ + --hash=sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265 \ + --hash=sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93 \ + --hash=sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929 \ + --hash=sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635 \ + --hash=sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709 \ + --hash=sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b \ + --hash=sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb \ + --hash=sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a \ + --hash=sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920 \ + --hash=sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e \ + --hash=sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544 \ + --hash=sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45 \ + --hash=sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390 \ + --hash=sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77 \ + --hash=sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355 \ + 
--hash=sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff \ + --hash=sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4 \ + --hash=sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7 \ + --hash=sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20 \ + --hash=sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c \ + --hash=sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162 \ + --hash=sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228 \ + --hash=sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437 \ + --hash=sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc \ + --hash=sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a \ + --hash=sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901 \ + --hash=sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4 \ + --hash=sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770 \ + --hash=sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525 \ + --hash=sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad \ + --hash=sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a \ + --hash=sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29 \ + --hash=sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90 \ + --hash=sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250 \ + --hash=sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d \ + --hash=sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3 \ + --hash=sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54 \ + --hash=sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f \ + --hash=sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1 \ + 
--hash=sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da \ + --hash=sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238 \ + --hash=sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa \ + --hash=sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523 \ + --hash=sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0 \ + --hash=sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205 \ + --hash=sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3 \ + --hash=sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4 \ + --hash=sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac \ + --hash=sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9 \ + --hash=sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb \ + --hash=sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced \ + --hash=sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd \ + --hash=sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0 \ + --hash=sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da \ + --hash=sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18 \ + --hash=sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9 \ + --hash=sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276 \ + --hash=sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333 \ + --hash=sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b \ + --hash=sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db \ + --hash=sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126 \ + --hash=sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9 \ + --hash=sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09 \ + 
--hash=sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0 \ + --hash=sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec \ + --hash=sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7 \ + --hash=sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff \ + --hash=sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9 \ + --hash=sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192 \ + --hash=sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8 \ + --hash=sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d \ + --hash=sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6 \ + --hash=sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797 \ + --hash=sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892 \ + --hash=sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f + # via matplotlib +libclang==16.0.6 \ + --hash=sha256:1e940048f51d0b0999099a9b78629ab8a64b62af5e9ff1b2b062439c21ee244d \ + --hash=sha256:4a9acbfd9c135a72f80d5dbff7588dfb0c81458244a89b9e83526e8595880e0a \ + --hash=sha256:4acdde39dfe410c877b4ccc0d4b57eb952100e4ee26bbdf6cfdb88e2033a7d31 \ + --hash=sha256:8130482120500476a027171f8f3c8dfc2536b591716eea71fc5da22cae13131b \ + --hash=sha256:88bc7e7b393c32e41e03ba77ef02fdd647da1f764c2cd028e69e0837080b79f6 \ + --hash=sha256:9dcdc730939788b8b69ffd6d5d75fe5366e3ee007f1e36a99799ec0b0c001492 \ + --hash=sha256:d80ed5827736ed5ec2bcedf536720476fd9d4fa4c79ef0cb24aea4c59332f361 \ + --hash=sha256:da9e47ebc3f0a6d90fb169ef25f9fbcd29b4a4ef97a8b0e3e3a17800af1423f4 \ + --hash=sha256:daab4a11dae228f1efa9efa3fe638b493b14d8d52c71fb3c7019e2f1df4514c2 \ + --hash=sha256:e1a5ad1e895e5443e205568c85c04b4608e4e973dae42f4dfd9cb46c81d1486b \ + --hash=sha256:f04e3060ae1f207f234d0608900c99c50edcb743e5e18276d78da2ddd727d39f + # via tensorflow +lxml==4.9.3 \ + 
--hash=sha256:05186a0f1346ae12553d66df1cfce6f251589fea3ad3da4f3ef4e34b2d58c6a3 \ + --hash=sha256:075b731ddd9e7f68ad24c635374211376aa05a281673ede86cbe1d1b3455279d \ + --hash=sha256:081d32421db5df44c41b7f08a334a090a545c54ba977e47fd7cc2deece78809a \ + --hash=sha256:0a3d3487f07c1d7f150894c238299934a2a074ef590b583103a45002035be120 \ + --hash=sha256:0bfd0767c5c1de2551a120673b72e5d4b628737cb05414f03c3277bf9bed3305 \ + --hash=sha256:0c0850c8b02c298d3c7006b23e98249515ac57430e16a166873fc47a5d549287 \ + --hash=sha256:0e2cb47860da1f7e9a5256254b74ae331687b9672dfa780eed355c4c9c3dbd23 \ + --hash=sha256:120fa9349a24c7043854c53cae8cec227e1f79195a7493e09e0c12e29f918e52 \ + --hash=sha256:1247694b26342a7bf47c02e513d32225ededd18045264d40758abeb3c838a51f \ + --hash=sha256:141f1d1a9b663c679dc524af3ea1773e618907e96075262726c7612c02b149a4 \ + --hash=sha256:14e019fd83b831b2e61baed40cab76222139926b1fb5ed0e79225bc0cae14584 \ + --hash=sha256:1509dd12b773c02acd154582088820893109f6ca27ef7291b003d0e81666109f \ + --hash=sha256:17a753023436a18e27dd7769e798ce302963c236bc4114ceee5b25c18c52c693 \ + --hash=sha256:1e224d5755dba2f4a9498e150c43792392ac9b5380aa1b845f98a1618c94eeef \ + --hash=sha256:1f447ea5429b54f9582d4b955f5f1985f278ce5cf169f72eea8afd9502973dd5 \ + --hash=sha256:23eed6d7b1a3336ad92d8e39d4bfe09073c31bfe502f20ca5116b2a334f8ec02 \ + --hash=sha256:25f32acefac14ef7bd53e4218fe93b804ef6f6b92ffdb4322bb6d49d94cad2bc \ + --hash=sha256:2c74524e179f2ad6d2a4f7caf70e2d96639c0954c943ad601a9e146c76408ed7 \ + --hash=sha256:303bf1edce6ced16bf67a18a1cf8339d0db79577eec5d9a6d4a80f0fb10aa2da \ + --hash=sha256:3331bece23c9ee066e0fb3f96c61322b9e0f54d775fccefff4c38ca488de283a \ + --hash=sha256:3e9bdd30efde2b9ccfa9cb5768ba04fe71b018a25ea093379c857c9dad262c40 \ + --hash=sha256:411007c0d88188d9f621b11d252cce90c4a2d1a49db6c068e3c16422f306eab8 \ + --hash=sha256:42871176e7896d5d45138f6d28751053c711ed4d48d8e30b498da155af39aebd \ + --hash=sha256:46f409a2d60f634fe550f7133ed30ad5321ae2e6630f13657fb9479506b00601 \ + 
--hash=sha256:48628bd53a426c9eb9bc066a923acaa0878d1e86129fd5359aee99285f4eed9c \ + --hash=sha256:48d6ed886b343d11493129e019da91d4039826794a3e3027321c56d9e71505be \ + --hash=sha256:4930be26af26ac545c3dffb662521d4e6268352866956672231887d18f0eaab2 \ + --hash=sha256:4aec80cde9197340bc353d2768e2a75f5f60bacda2bab72ab1dc499589b3878c \ + --hash=sha256:4c28a9144688aef80d6ea666c809b4b0e50010a2aca784c97f5e6bf143d9f129 \ + --hash=sha256:4d2d1edbca80b510443f51afd8496be95529db04a509bc8faee49c7b0fb6d2cc \ + --hash=sha256:4dd9a263e845a72eacb60d12401e37c616438ea2e5442885f65082c276dfb2b2 \ + --hash=sha256:4f1026bc732b6a7f96369f7bfe1a4f2290fb34dce00d8644bc3036fb351a4ca1 \ + --hash=sha256:4fb960a632a49f2f089d522f70496640fdf1218f1243889da3822e0a9f5f3ba7 \ + --hash=sha256:50670615eaf97227d5dc60de2dc99fb134a7130d310d783314e7724bf163f75d \ + --hash=sha256:50baa9c1c47efcaef189f31e3d00d697c6d4afda5c3cde0302d063492ff9b477 \ + --hash=sha256:53ace1c1fd5a74ef662f844a0413446c0629d151055340e9893da958a374f70d \ + --hash=sha256:5515edd2a6d1a5a70bfcdee23b42ec33425e405c5b351478ab7dc9347228f96e \ + --hash=sha256:56dc1f1ebccc656d1b3ed288f11e27172a01503fc016bcabdcbc0978b19352b7 \ + --hash=sha256:578695735c5a3f51569810dfebd05dd6f888147a34f0f98d4bb27e92b76e05c2 \ + --hash=sha256:57aba1bbdf450b726d58b2aea5fe47c7875f5afb2c4a23784ed78f19a0462574 \ + --hash=sha256:57d6ba0ca2b0c462f339640d22882acc711de224d769edf29962b09f77129cbf \ + --hash=sha256:5c245b783db29c4e4fbbbfc9c5a78be496c9fea25517f90606aa1f6b2b3d5f7b \ + --hash=sha256:5c31c7462abdf8f2ac0577d9f05279727e698f97ecbb02f17939ea99ae8daa98 \ + --hash=sha256:64f479d719dc9f4c813ad9bb6b28f8390360660b73b2e4beb4cb0ae7104f1c12 \ + --hash=sha256:65299ea57d82fb91c7f019300d24050c4ddeb7c5a190e076b5f48a2b43d19c42 \ + --hash=sha256:6689a3d7fd13dc687e9102a27e98ef33730ac4fe37795d5036d18b4d527abd35 \ + --hash=sha256:690dafd0b187ed38583a648076865d8c229661ed20e48f2335d68e2cf7dc829d \ + --hash=sha256:6fc3c450eaa0b56f815c7b62f2b7fba7266c4779adcf1cece9e6deb1de7305ce \ + 
--hash=sha256:704f61ba8c1283c71b16135caf697557f5ecf3e74d9e453233e4771d68a1f42d \ + --hash=sha256:71c52db65e4b56b8ddc5bb89fb2e66c558ed9d1a74a45ceb7dcb20c191c3df2f \ + --hash=sha256:71d66ee82e7417828af6ecd7db817913cb0cf9d4e61aa0ac1fde0583d84358db \ + --hash=sha256:7d298a1bd60c067ea75d9f684f5f3992c9d6766fadbc0bcedd39750bf344c2f4 \ + --hash=sha256:8b77946fd508cbf0fccd8e400a7f71d4ac0e1595812e66025bac475a8e811694 \ + --hash=sha256:8d7e43bd40f65f7d97ad8ef5c9b1778943d02f04febef12def25f7583d19baac \ + --hash=sha256:8df133a2ea5e74eef5e8fc6f19b9e085f758768a16e9877a60aec455ed2609b2 \ + --hash=sha256:8ed74706b26ad100433da4b9d807eae371efaa266ffc3e9191ea436087a9d6a7 \ + --hash=sha256:92af161ecbdb2883c4593d5ed4815ea71b31fafd7fd05789b23100d081ecac96 \ + --hash=sha256:97047f0d25cd4bcae81f9ec9dc290ca3e15927c192df17331b53bebe0e3ff96d \ + --hash=sha256:9719fe17307a9e814580af1f5c6e05ca593b12fb7e44fe62450a5384dbf61b4b \ + --hash=sha256:9767e79108424fb6c3edf8f81e6730666a50feb01a328f4a016464a5893f835a \ + --hash=sha256:9a92d3faef50658dd2c5470af249985782bf754c4e18e15afb67d3ab06233f13 \ + --hash=sha256:9bb6ad405121241e99a86efff22d3ef469024ce22875a7ae045896ad23ba2340 \ + --hash=sha256:9e28c51fa0ce5674be9f560c6761c1b441631901993f76700b1b30ca6c8378d6 \ + --hash=sha256:aca086dc5f9ef98c512bac8efea4483eb84abbf926eaeedf7b91479feb092458 \ + --hash=sha256:ae8b9c6deb1e634ba4f1930eb67ef6e6bf6a44b6eb5ad605642b2d6d5ed9ce3c \ + --hash=sha256:b0a545b46b526d418eb91754565ba5b63b1c0b12f9bd2f808c852d9b4b2f9b5c \ + --hash=sha256:b4e4bc18382088514ebde9328da057775055940a1f2e18f6ad2d78aa0f3ec5b9 \ + --hash=sha256:b6420a005548ad52154c8ceab4a1290ff78d757f9e5cbc68f8c77089acd3c432 \ + --hash=sha256:b86164d2cff4d3aaa1f04a14685cbc072efd0b4f99ca5708b2ad1b9b5988a991 \ + --hash=sha256:bb3bb49c7a6ad9d981d734ef7c7193bc349ac338776a0360cc671eaee89bcf69 \ + --hash=sha256:bef4e656f7d98aaa3486d2627e7d2df1157d7e88e7efd43a65aa5dd4714916cf \ + --hash=sha256:c0781a98ff5e6586926293e59480b64ddd46282953203c76ae15dbbbf302e8bb \ + 
--hash=sha256:c2006f5c8d28dee289f7020f721354362fa304acbaaf9745751ac4006650254b \ + --hash=sha256:c41bfca0bd3532d53d16fd34d20806d5c2b1ace22a2f2e4c0008570bf2c58833 \ + --hash=sha256:cd47b4a0d41d2afa3e58e5bf1f62069255aa2fd6ff5ee41604418ca925911d76 \ + --hash=sha256:cdb650fc86227eba20de1a29d4b2c1bfe139dc75a0669270033cb2ea3d391b85 \ + --hash=sha256:cef2502e7e8a96fe5ad686d60b49e1ab03e438bd9123987994528febd569868e \ + --hash=sha256:d27be7405547d1f958b60837dc4c1007da90b8b23f54ba1f8b728c78fdb19d50 \ + --hash=sha256:d37017287a7adb6ab77e1c5bee9bcf9660f90ff445042b790402a654d2ad81d8 \ + --hash=sha256:d3ff32724f98fbbbfa9f49d82852b159e9784d6094983d9a8b7f2ddaebb063d4 \ + --hash=sha256:d73d8ecf8ecf10a3bd007f2192725a34bd62898e8da27eb9d32a58084f93962b \ + --hash=sha256:dd708cf4ee4408cf46a48b108fb9427bfa00b9b85812a9262b5c668af2533ea5 \ + --hash=sha256:e3cd95e10c2610c360154afdc2f1480aea394f4a4f1ea0a5eacce49640c9b190 \ + --hash=sha256:e4da8ca0c0c0aea88fd46be8e44bd49716772358d648cce45fe387f7b92374a7 \ + --hash=sha256:eadfbbbfb41b44034a4c757fd5d70baccd43296fb894dba0295606a7cf3124aa \ + --hash=sha256:ed667f49b11360951e201453fc3967344d0d0263aa415e1619e85ae7fd17b4e0 \ + --hash=sha256:f3df3db1d336b9356dd3112eae5f5c2b8b377f3bc826848567f10bfddfee77e9 \ + --hash=sha256:f6bdac493b949141b733c5345b6ba8f87a226029cbabc7e9e121a413e49441e0 \ + --hash=sha256:fbf521479bcac1e25a663df882c46a641a9bff6b56dc8b0fafaebd2f66fb231b \ + --hash=sha256:fc9b106a1bf918db68619fdcd6d5ad4f972fdd19c01d19bdb6bf63f3589a9ec5 \ + --hash=sha256:fcdd00edfd0a3001e0181eab3e63bd5c74ad3e67152c84f93f13769a40e073a7 \ + --hash=sha256:fe4bda6bd4340caa6e5cf95e73f8fea5c4bfc55763dd42f1b50a94c1b4a2fbd4 + # via sacrebleu +markdown==3.5 \ + --hash=sha256:4afb124395ce5fc34e6d9886dab977fd9ae987fc6e85689f08278cf0c69d4bf3 \ + --hash=sha256:a807eb2e4778d9156c8f07876c6e4d50b5494c5665c4834f67b06459dfd877b3 + # via tensorboard +markupsafe==2.1.3 \ + --hash=sha256:05fb21170423db021895e1ea1e1f3ab3adb85d1c2333cbc2310f2a26bc77272e \ + 
--hash=sha256:0a4e4a1aff6c7ac4cd55792abf96c915634c2b97e3cc1c7129578aa68ebd754e \ + --hash=sha256:10bbfe99883db80bdbaff2dcf681dfc6533a614f700da1287707e8a5d78a8431 \ + --hash=sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686 \ + --hash=sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c \ + --hash=sha256:1577735524cdad32f9f694208aa75e422adba74f1baee7551620e43a3141f559 \ + --hash=sha256:1b40069d487e7edb2676d3fbdb2b0829ffa2cd63a2ec26c4938b2d34391b4ecc \ + --hash=sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb \ + --hash=sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939 \ + --hash=sha256:282c2cb35b5b673bbcadb33a585408104df04f14b2d9b01d4c345a3b92861c2c \ + --hash=sha256:2c1b19b3aaacc6e57b7e25710ff571c24d6c3613a45e905b1fde04d691b98ee0 \ + --hash=sha256:2ef12179d3a291be237280175b542c07a36e7f60718296278d8593d21ca937d4 \ + --hash=sha256:338ae27d6b8745585f87218a3f23f1512dbf52c26c28e322dbe54bcede54ccb9 \ + --hash=sha256:3c0fae6c3be832a0a0473ac912810b2877c8cb9d76ca48de1ed31e1c68386575 \ + --hash=sha256:3fd4abcb888d15a94f32b75d8fd18ee162ca0c064f35b11134be77050296d6ba \ + --hash=sha256:42de32b22b6b804f42c5d98be4f7e5e977ecdd9ee9b660fda1a3edf03b11792d \ + --hash=sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd \ + --hash=sha256:504b320cd4b7eff6f968eddf81127112db685e81f7e36e75f9f84f0df46041c3 \ + --hash=sha256:525808b8019e36eb524b8c68acdd63a37e75714eac50e988180b169d64480a00 \ + --hash=sha256:56d9f2ecac662ca1611d183feb03a3fa4406469dafe241673d521dd5ae92a155 \ + --hash=sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac \ + --hash=sha256:65c1a9bcdadc6c28eecee2c119465aebff8f7a584dd719facdd9e825ec61ab52 \ + --hash=sha256:68e78619a61ecf91e76aa3e6e8e33fc4894a2bebe93410754bd28fce0a8a4f9f \ + --hash=sha256:69c0f17e9f5a7afdf2cc9fb2d1ce6aabdb3bafb7f38017c0b77862bcec2bbad8 \ + --hash=sha256:6b2b56950d93e41f33b4223ead100ea0fe11f8e6ee5f641eb753ce4b77a7042b \ + 
--hash=sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007 \ + --hash=sha256:787003c0ddb00500e49a10f2844fac87aa6ce977b90b0feaaf9de23c22508b24 \ + --hash=sha256:7ef3cb2ebbf91e330e3bb937efada0edd9003683db6b57bb108c4001f37a02ea \ + --hash=sha256:8023faf4e01efadfa183e863fefde0046de576c6f14659e8782065bcece22198 \ + --hash=sha256:8758846a7e80910096950b67071243da3e5a20ed2546e6392603c096778d48e0 \ + --hash=sha256:8afafd99945ead6e075b973fefa56379c5b5c53fd8937dad92c662da5d8fd5ee \ + --hash=sha256:8c41976a29d078bb235fea9b2ecd3da465df42a562910f9022f1a03107bd02be \ + --hash=sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2 \ + --hash=sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1 \ + --hash=sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707 \ + --hash=sha256:962f82a3086483f5e5f64dbad880d31038b698494799b097bc59c2edf392fce6 \ + --hash=sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c \ + --hash=sha256:9dcdfd0eaf283af041973bff14a2e143b8bd64e069f4c383416ecd79a81aab58 \ + --hash=sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823 \ + --hash=sha256:aa7bd130efab1c280bed0f45501b7c8795f9fdbeb02e965371bbef3523627779 \ + --hash=sha256:ab4a0df41e7c16a1392727727e7998a467472d0ad65f3ad5e6e765015df08636 \ + --hash=sha256:ad9e82fb8f09ade1c3e1b996a6337afac2b8b9e365f926f5a61aacc71adc5b3c \ + --hash=sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad \ + --hash=sha256:b076b6226fb84157e3f7c971a47ff3a679d837cf338547532ab866c57930dbee \ + --hash=sha256:b7ff0f54cb4ff66dd38bebd335a38e2c22c41a8ee45aa608efc890ac3e3931bc \ + --hash=sha256:bfce63a9e7834b12b87c64d6b155fdd9b3b96191b6bd334bf37db7ff1fe457f2 \ + --hash=sha256:c011a4149cfbcf9f03994ec2edffcb8b1dc2d2aede7ca243746df97a5d41ce48 \ + --hash=sha256:c9c804664ebe8f83a211cace637506669e7890fec1b4195b505c214e50dd4eb7 \ + --hash=sha256:ca379055a47383d02a5400cb0d110cef0a776fc644cda797db0c5696cfd7e18e \ + 
--hash=sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b \ + --hash=sha256:cd0f502fe016460680cd20aaa5a76d241d6f35a1c3350c474bac1273803893fa \ + --hash=sha256:ceb01949af7121f9fc39f7d27f91be8546f3fb112c608bc4029aef0bab86a2a5 \ + --hash=sha256:d080e0a5eb2529460b30190fcfcc4199bd7f827663f858a226a81bc27beaa97e \ + --hash=sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb \ + --hash=sha256:df0be2b576a7abbf737b1575f048c23fb1d769f267ec4358296f31c2479db8f9 \ + --hash=sha256:e09031c87a1e51556fdcb46e5bd4f59dfb743061cf93c4d6831bf894f125eb57 \ + --hash=sha256:e4dd52d80b8c83fdce44e12478ad2e85c64ea965e75d66dbeafb0a3e77308fcc \ + --hash=sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc \ + --hash=sha256:fec21693218efe39aa7f8599346e90c705afa52c5b31ae019b2e57e8f6542bb2 \ + --hash=sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11 + # via werkzeug +matplotlib==3.7.0 \ + --hash=sha256:01681566e95b9423021b49dea6a2395c16fa054604eacb87f0f4c439750f9114 \ + --hash=sha256:03eb2c8ff8d85da679b71e14c7c95d16d014c48e0c0bfa14db85f6cdc5c92aad \ + --hash=sha256:092e6abc80cdf8a95f7d1813e16c0e99ceda8d5b195a3ab859c680f3487b80a2 \ + --hash=sha256:0a776462a4a63c0bfc9df106c15a0897aa2dbab6795c693aa366e8e283958854 \ + --hash=sha256:0dfd4a0cbd151f6439e6d7f8dca5292839ca311e7e650596d073774847ca2e4f \ + --hash=sha256:111ef351f28fd823ed7177632070a6badd6f475607122bc9002a526f2502a0b5 \ + --hash=sha256:21269450243d6928da81a9bed201f0909432a74e7d0d65db5545b9fa8a0d0223 \ + --hash=sha256:21a8aeac39b4a795e697265d800ce52ab59bdeb6bb23082e2d971f3041074f02 \ + --hash=sha256:21bd4033c40b95abd5b8453f036ed5aa70856e56ecbd887705c37dce007a4c21 \ + --hash=sha256:3493b48e56468c39bd9c1532566dff3b8062952721b7521e1f394eb6791495f4 \ + --hash=sha256:3a10428d4f8d1a478ceabd652e61a175b2fdeed4175ab48da4a7b8deb561e3fa \ + --hash=sha256:3d1e52365d8d5af699f04581ca191112e1d1220a9ce4386b57d807124d8b55e6 \ + 
--hash=sha256:3da8b9618188346239e51f1ea6c0f8f05c6e218cfcc30b399dd7dd7f52e8bceb \ + --hash=sha256:4497d88c559b76da320b7759d64db442178beeea06a52dc0c629086982082dcd \ + --hash=sha256:46ca923e980f76d34c1c633343a72bb042d6ba690ecc649aababf5317997171d \ + --hash=sha256:4f640534ec2760e270801056bc0d8a10777c48b30966eef78a7c35d8590915ba \ + --hash=sha256:51fb664c37714cbaac69c16d6b3719f517a13c96c3f76f4caadd5a0aa7ed0329 \ + --hash=sha256:56b7b79488209041a9bf7ddc34f1b069274489ce69e34dc63ae241d0d6b4b736 \ + --hash=sha256:691ef1f15360e439886186d0db77b5345b24da12cbc4fc57b26c4826db4d6cab \ + --hash=sha256:71b751d06b2ed1fd017de512d7439c0259822864ea16731522b251a27c0b2ede \ + --hash=sha256:7d0dcd1a0bf8d56551e8617d6dc3881d8a1c7fb37d14e5ec12cbb293f3e6170a \ + --hash=sha256:827e78239292e561cfb70abf356a9d7eaf5bf6a85c97877f254009f20b892f89 \ + --hash=sha256:8665855f3919c80551f377bc16df618ceabf3ef65270bc14b60302dce88ca9ab \ + --hash=sha256:8f6efd313430d7ef70a38a3276281cb2e8646b3a22b3b21eb227da20e15e6813 \ + --hash=sha256:9d85355c48ef8b9994293eb7c00f44aa8a43cad7a297fbf0770a25cdb2244b91 \ + --hash=sha256:a06a6c9822e80f323549c6bc9da96d4f233178212ad9a5f4ab87fd153077a507 \ + --hash=sha256:b51ab8a5d5d3bbd4527af633a638325f492e09e45e78afdf816ef55217a09664 \ + --hash=sha256:c0592ba57217c22987b7322df10f75ef95bc44dce781692b4b7524085de66019 \ + --hash=sha256:c5465735eaaafd1cfaec3fed60aee776aeb3fd3992aa2e49f4635339c931d443 \ + --hash=sha256:c849aa94ff2a70fb71f318f48a61076d1205c6013b9d3885ade7f992093ac434 \ + --hash=sha256:c869b646489c6a94375714032e5cec08e3aa8d3f7d4e8ef2b0fb50a52b317ce6 \ + --hash=sha256:cb52aa97b92acdee090edfb65d1cb84ea60ab38e871ba8321a10bbcebc2a3540 \ + --hash=sha256:cf119eee4e57389fba5ac8b816934e95c256535e55f0b21628b4205737d1de85 \ + --hash=sha256:cf6346644e8fe234dc847e6232145dac199a650d3d8025b3ef65107221584ba4 \ + --hash=sha256:de20eb1247725a2f889173d391a6d9e7e0f2540feda24030748283108b0478ec \ + --hash=sha256:eb2e76cd429058d8954121c334dddfcd11a6186c6975bca61f3f248c99031b05 \ + 
--hash=sha256:f336e7014889c38c59029ebacc35c59236a852e4b23836708cfd3f43d1eaeed5 \ + --hash=sha256:f4ddac5f59e78d04b20469bc43853a8e619bb6505c7eac8ffb343ff2c516d72f \ + --hash=sha256:f910d924da8b9fb066b5beae0b85e34ed1b6293014892baadcf2a51da1c65807 \ + --hash=sha256:f91d35b3ef51d29d9c661069b9e4ba431ce283ffc533b981506889e144b5b40e \ + --hash=sha256:fb0304c1cd802e9a25743414c887e8a7cd51d96c9ec96d388625d2cd1c137ae3 + # via + # -r requirements.in + # flax +mesh-tensorflow[transformer]==0.1.21 \ + --hash=sha256:747d46696ad260ae59a566fc1726749b99080973b07aae098be5714931a3ad80 \ + --hash=sha256:f674afcd260cc6c506b00f623aeb53a2a72e2afa1a318c95b936d961777d8d94 + # via t5 +ml-collections==0.1.1 \ + --hash=sha256:3fefcc72ec433aa1e5d32307a3e474bbb67f405be814ea52a2166bfc9dbe68cc + # via clu +msgpack==1.0.7 \ + --hash=sha256:04ad6069c86e531682f9e1e71b71c1c3937d6014a7c3e9edd2aa81ad58842862 \ + --hash=sha256:0bfdd914e55e0d2c9e1526de210f6fe8ffe9705f2b1dfcc4aecc92a4cb4b533d \ + --hash=sha256:1dc93e8e4653bdb5910aed79f11e165c85732067614f180f70534f056da97db3 \ + --hash=sha256:1e2d69948e4132813b8d1131f29f9101bc2c915f26089a6d632001a5c1349672 \ + --hash=sha256:235a31ec7db685f5c82233bddf9858748b89b8119bf4538d514536c485c15fe0 \ + --hash=sha256:27dcd6f46a21c18fa5e5deed92a43d4554e3df8d8ca5a47bf0615d6a5f39dbc9 \ + --hash=sha256:28efb066cde83c479dfe5a48141a53bc7e5f13f785b92ddde336c716663039ee \ + --hash=sha256:3476fae43db72bd11f29a5147ae2f3cb22e2f1a91d575ef130d2bf49afd21c46 \ + --hash=sha256:36e17c4592231a7dbd2ed09027823ab295d2791b3b1efb2aee874b10548b7524 \ + --hash=sha256:384d779f0d6f1b110eae74cb0659d9aa6ff35aaf547b3955abf2ab4c901c4819 \ + --hash=sha256:38949d30b11ae5f95c3c91917ee7a6b239f5ec276f271f28638dec9156f82cfc \ + --hash=sha256:3967e4ad1aa9da62fd53e346ed17d7b2e922cba5ab93bdd46febcac39be636fc \ + --hash=sha256:3e7bf4442b310ff154b7bb9d81eb2c016b7d597e364f97d72b1acc3817a0fdc1 \ + --hash=sha256:3f0c8c6dfa6605ab8ff0611995ee30d4f9fcff89966cf562733b4008a3d60d82 \ + 
--hash=sha256:484ae3240666ad34cfa31eea7b8c6cd2f1fdaae21d73ce2974211df099a95d81 \ + --hash=sha256:4a7b4f35de6a304b5533c238bee86b670b75b03d31b7797929caa7a624b5dda6 \ + --hash=sha256:4cb14ce54d9b857be9591ac364cb08dc2d6a5c4318c1182cb1d02274029d590d \ + --hash=sha256:4e71bc4416de195d6e9b4ee93ad3f2f6b2ce11d042b4d7a7ee00bbe0358bd0c2 \ + --hash=sha256:52700dc63a4676669b341ba33520f4d6e43d3ca58d422e22ba66d1736b0a6e4c \ + --hash=sha256:572efc93db7a4d27e404501975ca6d2d9775705c2d922390d878fcf768d92c87 \ + --hash=sha256:576eb384292b139821c41995523654ad82d1916da6a60cff129c715a6223ea84 \ + --hash=sha256:5b0bf0effb196ed76b7ad883848143427a73c355ae8e569fa538365064188b8e \ + --hash=sha256:5b6ccc0c85916998d788b295765ea0e9cb9aac7e4a8ed71d12e7d8ac31c23c95 \ + --hash=sha256:5ed82f5a7af3697b1c4786053736f24a0efd0a1b8a130d4c7bfee4b9ded0f08f \ + --hash=sha256:6d4c80667de2e36970ebf74f42d1088cc9ee7ef5f4e8c35eee1b40eafd33ca5b \ + --hash=sha256:730076207cb816138cf1af7f7237b208340a2c5e749707457d70705715c93b93 \ + --hash=sha256:7687e22a31e976a0e7fc99c2f4d11ca45eff652a81eb8c8085e9609298916dcf \ + --hash=sha256:822ea70dc4018c7e6223f13affd1c5c30c0f5c12ac1f96cd8e9949acddb48a61 \ + --hash=sha256:84b0daf226913133f899ea9b30618722d45feffa67e4fe867b0b5ae83a34060c \ + --hash=sha256:85765fdf4b27eb5086f05ac0491090fc76f4f2b28e09d9350c31aac25a5aaff8 \ + --hash=sha256:8dd178c4c80706546702c59529ffc005681bd6dc2ea234c450661b205445a34d \ + --hash=sha256:8f5b234f567cf76ee489502ceb7165c2a5cecec081db2b37e35332b537f8157c \ + --hash=sha256:98bbd754a422a0b123c66a4c341de0474cad4a5c10c164ceed6ea090f3563db4 \ + --hash=sha256:993584fc821c58d5993521bfdcd31a4adf025c7d745bbd4d12ccfecf695af5ba \ + --hash=sha256:a40821a89dc373d6427e2b44b572efc36a2778d3f543299e2f24eb1a5de65415 \ + --hash=sha256:b291f0ee7961a597cbbcc77709374087fa2a9afe7bdb6a40dbbd9b127e79afee \ + --hash=sha256:b573a43ef7c368ba4ea06050a957c2a7550f729c31f11dd616d2ac4aba99888d \ + --hash=sha256:b610ff0f24e9f11c9ae653c67ff8cc03c075131401b3e5ef4b82570d1728f8a9 \ + 
--hash=sha256:bdf38ba2d393c7911ae989c3bbba510ebbcdf4ecbdbfec36272abe350c454075 \ + --hash=sha256:bfef2bb6ef068827bbd021017a107194956918ab43ce4d6dc945ffa13efbc25f \ + --hash=sha256:cab3db8bab4b7e635c1c97270d7a4b2a90c070b33cbc00c99ef3f9be03d3e1f7 \ + --hash=sha256:cb70766519500281815dfd7a87d3a178acf7ce95390544b8c90587d76b227681 \ + --hash=sha256:cca1b62fe70d761a282496b96a5e51c44c213e410a964bdffe0928e611368329 \ + --hash=sha256:ccf9a39706b604d884d2cb1e27fe973bc55f2890c52f38df742bc1d79ab9f5e1 \ + --hash=sha256:dc43f1ec66eb8440567186ae2f8c447d91e0372d793dfe8c222aec857b81a8cf \ + --hash=sha256:dd632777ff3beaaf629f1ab4396caf7ba0bdd075d948a69460d13d44357aca4c \ + --hash=sha256:e45ae4927759289c30ccba8d9fdce62bb414977ba158286b5ddaf8df2cddb5c5 \ + --hash=sha256:e50ebce52f41370707f1e21a59514e3375e3edd6e1832f5e5235237db933c98b \ + --hash=sha256:ebbbba226f0a108a7366bf4b59bf0f30a12fd5e75100c630267d94d7f0ad20e5 \ + --hash=sha256:ec79ff6159dffcc30853b2ad612ed572af86c92b5168aa3fc01a67b0fa40665e \ + --hash=sha256:f0936e08e0003f66bfd97e74ee530427707297b0d0361247e9b4f59ab78ddc8b \ + --hash=sha256:f26a07a6e877c76a88e3cecac8531908d980d3d5067ff69213653649ec0f60ad \ + --hash=sha256:f64e376cd20d3f030190e8c32e1c64582eba56ac6dc7d5b0b49a9d44021b52fd \ + --hash=sha256:f6ffbc252eb0d229aeb2f9ad051200668fc3a9aaa8994e49f0cb2ffe2b7867e7 \ + --hash=sha256:f9a7c509542db4eceed3dcf21ee5267ab565a83555c9b88a8109dcecc4709002 \ + --hash=sha256:ff1d0899f104f3921d94579a5638847f783c9b04f2d5f229392ca77fba5b82fc + # via flax +nltk==3.8.1 \ + --hash=sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3 \ + --hash=sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5 + # via + # rouge-score + # t5 +numpy==1.23.5 \ + --hash=sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d \ + --hash=sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07 \ + --hash=sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df \ + 
--hash=sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9 \ + --hash=sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d \ + --hash=sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a \ + --hash=sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719 \ + --hash=sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2 \ + --hash=sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280 \ + --hash=sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa \ + --hash=sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387 \ + --hash=sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1 \ + --hash=sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43 \ + --hash=sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f \ + --hash=sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398 \ + --hash=sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63 \ + --hash=sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de \ + --hash=sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8 \ + --hash=sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481 \ + --hash=sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0 \ + --hash=sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d \ + --hash=sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e \ + --hash=sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96 \ + --hash=sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb \ + --hash=sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6 \ + --hash=sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d \ + --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ + 
--hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 + # via + # -r requirements.in + # chex + # clu + # contourpy + # etils + # flax + # h5py + # jax + # jaxlib + # matplotlib + # opt-einsum + # optax + # pandas + # rouge-score + # sacrebleu + # scikit-learn + # scipy + # seqio + # seqio-nightly + # t5 + # tensorboard + # tensorflow + # tensorflow-datasets + # tensorflow-hub + # tensorstore + # tfds-nightly + # transformers +oauthlib==3.2.2 \ + --hash=sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca \ + --hash=sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918 + # via requests-oauthlib +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via + # jax + # tensorflow +optax==0.1.7 \ + --hash=sha256:2b85115f2ae7adafe5fd9abf4b275e53057765361511c8ccc868e70158458494 \ + --hash=sha256:6a5a848bc5e55e619b187c749fdddc4a5443ea14be85cc769f995779865c110d + # via + # -r requirements.in + # flax +packaging==23.2 \ + --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ + --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 + # via + # clu + # huggingface-hub + # matplotlib + # seqio + # seqio-nightly + # tensorflow + # transformers +pandas==2.1.1 \ + --hash=sha256:02304e11582c5d090e5a52aec726f31fe3f42895d6bfc1f28738f9b64b6f0614 \ + --hash=sha256:0489b0e6aa3d907e909aef92975edae89b1ee1654db5eafb9be633b0124abe97 \ + --hash=sha256:05674536bd477af36aa2effd4ec8f71b92234ce0cc174de34fd21e2ee99adbc2 \ + --hash=sha256:25e8474a8eb258e391e30c288eecec565bfed3e026f312b0cbd709a63906b6f8 \ + --hash=sha256:29deb61de5a8a93bdd033df328441a79fcf8dd3c12d5ed0b41a395eef9cd76f0 \ + --hash=sha256:366da7b0e540d1b908886d4feb3d951f2f1e572e655c1160f5fde28ad4abb750 \ + --hash=sha256:3bcad1e6fb34b727b016775bea407311f7721db87e5b409e6542f4546a4951ea \ + 
--hash=sha256:4c3f32fd7c4dccd035f71734df39231ac1a6ff95e8bdab8d891167197b7018d2 \ + --hash=sha256:4cdb0fab0400c2cb46dafcf1a0fe084c8bb2480a1fa8d81e19d15e12e6d4ded2 \ + --hash=sha256:4f99bebf19b7e03cf80a4e770a3e65eee9dd4e2679039f542d7c1ace7b7b1daa \ + --hash=sha256:58d997dbee0d4b64f3cb881a24f918b5f25dd64ddf31f467bb9b67ae4c63a1e4 \ + --hash=sha256:75ce97667d06d69396d72be074f0556698c7f662029322027c226fd7a26965cb \ + --hash=sha256:84e7e910096416adec68075dc87b986ff202920fb8704e6d9c8c9897fe7332d6 \ + --hash=sha256:9e2959720b70e106bb1d8b6eadd8ecd7c8e99ccdbe03ee03260877184bb2877d \ + --hash=sha256:9e50e72b667415a816ac27dfcfe686dc5a0b02202e06196b943d54c4f9c7693e \ + --hash=sha256:a0dbfea0dd3901ad4ce2306575c54348d98499c95be01b8d885a2737fe4d7a98 \ + --hash=sha256:b407381258a667df49d58a1b637be33e514b07f9285feb27769cedb3ab3d0b3a \ + --hash=sha256:b8bd1685556f3374520466998929bade3076aeae77c3e67ada5ed2b90b4de7f0 \ + --hash=sha256:c1f84c144dee086fe4f04a472b5cd51e680f061adf75c1ae4fc3a9275560f8f4 \ + --hash=sha256:c747793c4e9dcece7bb20156179529898abf505fe32cb40c4052107a3c620b49 \ + --hash=sha256:cc1ab6a25da197f03ebe6d8fa17273126120874386b4ac11c1d687df288542dd \ + --hash=sha256:dc3657869c7902810f32bd072f0740487f9e030c1a3ab03e0af093db35a9d14e \ + --hash=sha256:f5ec7740f9ccb90aec64edd71434711f58ee0ea7f5ed4ac48be11cfa9abf7317 \ + --hash=sha256:fecb198dc389429be557cde50a2d46da8434a17fe37d7d41ff102e3987fd947b \ + --hash=sha256:ffa8f0966de2c22de408d0e322db2faed6f6e74265aa0856f3824813cf124363 + # via t5 +pillow==10.0.1 \ + --hash=sha256:0462b1496505a3462d0f35dc1c4d7b54069747d65d00ef48e736acda2c8cbdff \ + --hash=sha256:186f7e04248103482ea6354af6d5bcedb62941ee08f7f788a1c7707bc720c66f \ + --hash=sha256:19e9adb3f22d4c416e7cd79b01375b17159d6990003633ff1d8377e21b7f1b21 \ + --hash=sha256:28444cb6ad49726127d6b340217f0627abc8732f1194fd5352dec5e6a0105635 \ + --hash=sha256:2872f2d7846cf39b3dbff64bc1104cc48c76145854256451d33c5faa55c04d1a \ + 
--hash=sha256:2cc6b86ece42a11f16f55fe8903595eff2b25e0358dec635d0a701ac9586588f \ + --hash=sha256:2d7e91b4379f7a76b31c2dda84ab9e20c6220488e50f7822e59dac36b0cd92b1 \ + --hash=sha256:2fa6dd2661838c66f1a5473f3b49ab610c98a128fc08afbe81b91a1f0bf8c51d \ + --hash=sha256:32bec7423cdf25c9038fef614a853c9d25c07590e1a870ed471f47fb80b244db \ + --hash=sha256:3855447d98cced8670aaa63683808df905e956f00348732448b5a6df67ee5849 \ + --hash=sha256:3a04359f308ebee571a3127fdb1bd01f88ba6f6fb6d087f8dd2e0d9bff43f2a7 \ + --hash=sha256:3a0d3e54ab1df9df51b914b2233cf779a5a10dfd1ce339d0421748232cea9876 \ + --hash=sha256:44e7e4587392953e5e251190a964675f61e4dae88d1e6edbe9f36d6243547ff3 \ + --hash=sha256:459307cacdd4138edee3875bbe22a2492519e060660eaf378ba3b405d1c66317 \ + --hash=sha256:4ce90f8a24e1c15465048959f1e94309dfef93af272633e8f37361b824532e91 \ + --hash=sha256:50bd5f1ebafe9362ad622072a1d2f5850ecfa44303531ff14353a4059113b12d \ + --hash=sha256:522ff4ac3aaf839242c6f4e5b406634bfea002469656ae8358644fc6c4856a3b \ + --hash=sha256:552912dbca585b74d75279a7570dd29fa43b6d93594abb494ebb31ac19ace6bd \ + --hash=sha256:5d6c9049c6274c1bb565021367431ad04481ebb54872edecfcd6088d27edd6ed \ + --hash=sha256:697a06bdcedd473b35e50a7e7506b1d8ceb832dc238a336bd6f4f5aa91a4b500 \ + --hash=sha256:71671503e3015da1b50bd18951e2f9daf5b6ffe36d16f1eb2c45711a301521a7 \ + --hash=sha256:723bd25051454cea9990203405fa6b74e043ea76d4968166dfd2569b0210886a \ + --hash=sha256:764d2c0daf9c4d40ad12fbc0abd5da3af7f8aa11daf87e4fa1b834000f4b6b0a \ + --hash=sha256:787bb0169d2385a798888e1122c980c6eff26bf941a8ea79747d35d8f9210ca0 \ + --hash=sha256:7f771e7219ff04b79e231d099c0a28ed83aa82af91fd5fa9fdb28f5b8d5addaf \ + --hash=sha256:847e8d1017c741c735d3cd1883fa7b03ded4f825a6e5fcb9378fd813edee995f \ + --hash=sha256:84efb46e8d881bb06b35d1d541aa87f574b58e87f781cbba8d200daa835b42e1 \ + --hash=sha256:898f1d306298ff40dc1b9ca24824f0488f6f039bc0e25cfb549d3195ffa17088 \ + --hash=sha256:8b451d6ead6e3500b6ce5c7916a43d8d8d25ad74b9102a629baccc0808c54971 \ + 
--hash=sha256:8f06be50669087250f319b706decf69ca71fdecd829091a37cc89398ca4dc17a \ + --hash=sha256:92a23b0431941a33242b1f0ce6c88a952e09feeea9af4e8be48236a68ffe2205 \ + --hash=sha256:93139acd8109edcdeffd85e3af8ae7d88b258b3a1e13a038f542b79b6d255c54 \ + --hash=sha256:98533fd7fa764e5f85eebe56c8e4094db912ccbe6fbf3a58778d543cadd0db08 \ + --hash=sha256:9f665d1e6474af9f9da5e86c2a3a2d2d6204e04d5af9c06b9d42afa6ebde3f21 \ + --hash=sha256:b059ac2c4c7a97daafa7dc850b43b2d3667def858a4f112d1aa082e5c3d6cf7d \ + --hash=sha256:b1be1c872b9b5fcc229adeadbeb51422a9633abd847c0ff87dc4ef9bb184ae08 \ + --hash=sha256:b7cf63d2c6928b51d35dfdbda6f2c1fddbe51a6bc4a9d4ee6ea0e11670dd981e \ + --hash=sha256:bc2e3069569ea9dbe88d6b8ea38f439a6aad8f6e7a6283a38edf61ddefb3a9bf \ + --hash=sha256:bcf1207e2f2385a576832af02702de104be71301c2696d0012b1b93fe34aaa5b \ + --hash=sha256:ca26ba5767888c84bf5a0c1a32f069e8204ce8c21d00a49c90dabeba00ce0145 \ + --hash=sha256:cbe68deb8580462ca0d9eb56a81912f59eb4542e1ef8f987405e35a0179f4ea2 \ + --hash=sha256:d6caf3cd38449ec3cd8a68b375e0c6fe4b6fd04edb6c9766b55ef84a6e8ddf2d \ + --hash=sha256:d72967b06be9300fed5cfbc8b5bafceec48bf7cdc7dab66b1d2549035287191d \ + --hash=sha256:d889b53ae2f030f756e61a7bff13684dcd77e9af8b10c6048fb2c559d6ed6eaf \ + --hash=sha256:de596695a75496deb3b499c8c4f8e60376e0516e1a774e7bc046f0f48cd620ad \ + --hash=sha256:e6a90167bcca1216606223a05e2cf991bb25b14695c518bc65639463d7db722d \ + --hash=sha256:ed2d9c0704f2dc4fa980b99d565c0c9a543fe5101c25b3d60488b8ba80f0cce1 \ + --hash=sha256:ee7810cf7c83fa227ba9125de6084e5e8b08c59038a7b2c9045ef4dde61663b4 \ + --hash=sha256:f0b4b06da13275bc02adfeb82643c4a6385bd08d26f03068c2796f60d125f6f2 \ + --hash=sha256:f11c9102c56ffb9ca87134bd025a43d2aba3f1155f508eff88f694b33a9c6d19 \ + --hash=sha256:f5bb289bb835f9fe1a1e9300d011eef4d69661bb9b34d5e196e5e82c4cb09b37 \ + --hash=sha256:f6d3d4c905e26354e8f9d82548475c46d8e0889538cb0657aa9c6f0872a37aa4 \ + --hash=sha256:fcb59711009b0168d6ee0bd8fb5eb259c4ab1717b2f538bbf36bacf207ef7a68 \ + 
--hash=sha256:fd2a5403a75b54661182b75ec6132437a181209b901446ee5724b589af8edef1 + # via matplotlib +portalocker==2.8.2 \ + --hash=sha256:2b035aa7828e46c58e9b31390ee1f169b98e1066ab10b9a6a861fe7e25ee4f33 \ + --hash=sha256:cfb86acc09b9aa7c3b43594e19be1345b9d16af3feb08bf92f23d4dce513a28e + # via sacrebleu +promise==2.3 \ + --hash=sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0 + # via + # tensorflow-datasets + # tfds-nightly +protobuf==3.20.3 \ + --hash=sha256:03038ac1cfbc41aa21f6afcbcd357281d7521b4157926f30ebecc8d4ea59dcb7 \ + --hash=sha256:28545383d61f55b57cf4df63eebd9827754fd2dc25f80c5253f9184235db242c \ + --hash=sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2 \ + --hash=sha256:398a9e0c3eaceb34ec1aee71894ca3299605fa8e761544934378bbc6c97de23b \ + --hash=sha256:44246bab5dd4b7fbd3c0c80b6f16686808fab0e4aca819ade6e8d294a29c7050 \ + --hash=sha256:447d43819997825d4e71bf5769d869b968ce96848b6479397e29fc24c4a5dfe9 \ + --hash=sha256:67a3598f0a2dcbc58d02dd1928544e7d88f764b47d4a286202913f0b2801c2e7 \ + --hash=sha256:74480f79a023f90dc6e18febbf7b8bac7508420f2006fabd512013c0c238f454 \ + --hash=sha256:819559cafa1a373b7096a482b504ae8a857c89593cf3a25af743ac9ecbd23480 \ + --hash=sha256:899dc660cd599d7352d6f10d83c95df430a38b410c1b66b407a6b29265d66469 \ + --hash=sha256:8c0c984a1b8fef4086329ff8dd19ac77576b384079247c770f29cc8ce3afa06c \ + --hash=sha256:9aae4406ea63d825636cc11ffb34ad3379335803216ee3a856787bcf5ccc751e \ + --hash=sha256:a7ca6d488aa8ff7f329d4c545b2dbad8ac31464f1d8b1c87ad1346717731e4db \ + --hash=sha256:b6cc7ba72a8850621bfec987cb72623e703b7fe2b9127a161ce61e61558ad905 \ + --hash=sha256:bf01b5720be110540be4286e791db73f84a2b721072a3711efff6c324cdf074b \ + --hash=sha256:c02ce36ec760252242a33967d51c289fd0e1c0e6e5cc9397e2279177716add86 \ + --hash=sha256:d9e4432ff660d67d775c66ac42a67cf2453c27cb4d738fc22cb53b5d84c135d4 \ + --hash=sha256:daa564862dd0d39c00f8086f88700fdbe8bc717e993a21e90711acfed02f2402 \ + 
--hash=sha256:de78575669dddf6099a8a0f46a27e82a1783c557ccc38ee620ed8cc96d3be7d7 \ + --hash=sha256:e64857f395505ebf3d2569935506ae0dfc4a15cb80dc25261176c784662cdcc4 \ + --hash=sha256:f4bd856d702e5b0d96a00ec6b307b0f51c1982c2bf9c0052cf9019e9a544ba99 \ + --hash=sha256:f4c42102bc82a51108e449cbb32b19b180022941c727bac0cfd50170341f16ee + # via + # googleapis-common-protos + # seqio + # seqio-nightly + # tensorboard + # tensorflow + # tensorflow-datasets + # tensorflow-hub + # tensorflow-metadata + # tfds-nightly +psutil==5.9.5 \ + --hash=sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d \ + --hash=sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217 \ + --hash=sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4 \ + --hash=sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c \ + --hash=sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f \ + --hash=sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da \ + --hash=sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4 \ + --hash=sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42 \ + --hash=sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5 \ + --hash=sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4 \ + --hash=sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9 \ + --hash=sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f \ + --hash=sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30 \ + --hash=sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48 + # via + # tensorflow-datasets + # tfds-nightly +pyasn1==0.5.0 \ + --hash=sha256:87a2121042a1ac9358cabcaf1d07680ff97ee6404333bacca15f76aa8ad01a57 \ + --hash=sha256:97b7290ca68e62a832558ec3976f15cbf911bf5d7c7039d8b861c2a0ece69fde + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 \ + 
--hash=sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c \ + --hash=sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d + # via google-auth +pyglove==0.4.3 \ + --hash=sha256:836ca495fcbf87a20b49b4638502b67816d0ea22e60f1cc28d128f216b21f172 \ + --hash=sha256:b9b5f8b24b8c7fdccdb4424ab67b735732b854cc778315ef305a1836a8219a25 + # via + # seqio + # seqio-nightly +pygments==2.16.1 \ + --hash=sha256:13fc09fa63bc8d8671a6d247e1eb303c4b343eaee81d861f3404db2935653692 \ + --hash=sha256:1daff0494820c69bc8941e407aa20f577374ee88364ee10a98fdbe0aece96e29 + # via rich +pyparsing==3.1.1 \ + --hash=sha256:32c7c0b711493c72ff18a981d24f28aaf9c1fb7ed5e9667c9e84e3db623bdbfb \ + --hash=sha256:ede28a1a32462f5a9705e07aea48001a08f7cf81a021585011deba701581a0db + # via matplotlib +pysocks==1.7.1 \ + --hash=sha256:08e69f092cc6dbe92a0fdd16eeb9b9ffbc13cadfe5ca4c7bd92ffb078b293299 \ + --hash=sha256:2725bd0a9925919b9b51739eea5f9e2bae91e83288108a9ad338b2e3a4435ee5 \ + --hash=sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0 + # via requests +python-dateutil==2.8.2 \ + --hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \ + --hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 + # via + # matplotlib + # pandas +pytz==2023.3.post1 \ + --hash=sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b \ + --hash=sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7 + # via pandas +pyyaml==6.0.1 \ + --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ + --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ + --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ + --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ + --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ + 
--hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ + --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ + --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ + --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ + --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ + --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ + --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ + --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ + --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ + --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ + --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ + --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ + --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ + --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ + --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ + --hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ + --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ + --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ + --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ + --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ + --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ + --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ + --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ + --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ + 
--hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ + --hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ + --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ + --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ + --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ + --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ + --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ + --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ + --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ + --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ + --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ + --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ + --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ + --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ + --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ + --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ + --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ + --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ + --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ + --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ + --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f + # via + # flax + # huggingface-hub + # ml-collections + # transformers +regex==2023.10.3 \ + --hash=sha256:00ba3c9818e33f1fa974693fb55d24cdc8ebafcb2e4207680669d8f8d7cca79a \ + --hash=sha256:00e871d83a45eee2f8688d7e6849609c2ca2a04a6d48fba3dff4deef35d14f07 \ + 
--hash=sha256:06e9abc0e4c9ab4779c74ad99c3fc10d3967d03114449acc2c2762ad4472b8ca \ + --hash=sha256:0b9ac09853b2a3e0d0082104036579809679e7715671cfbf89d83c1cb2a30f58 \ + --hash=sha256:0d47840dc05e0ba04fe2e26f15126de7c755496d5a8aae4a08bda4dd8d646c54 \ + --hash=sha256:0f649fa32fe734c4abdfd4edbb8381c74abf5f34bc0b3271ce687b23729299ed \ + --hash=sha256:107ac60d1bfdc3edb53be75e2a52aff7481b92817cfdddd9b4519ccf0e54a6ff \ + --hash=sha256:11175910f62b2b8c055f2b089e0fedd694fe2be3941b3e2633653bc51064c528 \ + --hash=sha256:12bd4bc2c632742c7ce20db48e0d99afdc05e03f0b4c1af90542e05b809a03d9 \ + --hash=sha256:16f8740eb6dbacc7113e3097b0a36065a02e37b47c936b551805d40340fb9971 \ + --hash=sha256:1c0e8fae5b27caa34177bdfa5a960c46ff2f78ee2d45c6db15ae3f64ecadde14 \ + --hash=sha256:2c54e23836650bdf2c18222c87f6f840d4943944146ca479858404fedeb9f9af \ + --hash=sha256:3367007ad1951fde612bf65b0dffc8fd681a4ab98ac86957d16491400d661302 \ + --hash=sha256:36362386b813fa6c9146da6149a001b7bd063dabc4d49522a1f7aa65b725c7ec \ + --hash=sha256:39807cbcbe406efca2a233884e169d056c35aa7e9f343d4e78665246a332f597 \ + --hash=sha256:39cdf8d141d6d44e8d5a12a8569d5a227f645c87df4f92179bd06e2e2705e76b \ + --hash=sha256:3b2c3502603fab52d7619b882c25a6850b766ebd1b18de3df23b2f939360e1bd \ + --hash=sha256:3ccf2716add72f80714b9a63899b67fa711b654be3fcdd34fa391d2d274ce767 \ + --hash=sha256:3fef4f844d2290ee0ba57addcec17eec9e3df73f10a2748485dfd6a3a188cc0f \ + --hash=sha256:4023e2efc35a30e66e938de5aef42b520c20e7eda7bb5fb12c35e5d09a4c43f6 \ + --hash=sha256:4a3ee019a9befe84fa3e917a2dd378807e423d013377a884c1970a3c2792d293 \ + --hash=sha256:4a8bf76e3182797c6b1afa5b822d1d5802ff30284abe4599e1247be4fd6b03be \ + --hash=sha256:4a992f702c9be9c72fa46f01ca6e18d131906a7180950958f766c2aa294d4b41 \ + --hash=sha256:4c34d4f73ea738223a094d8e0ffd6d2c1a1b4c175da34d6b0de3d8d69bee6bcc \ + --hash=sha256:4cd1bccf99d3ef1ab6ba835308ad85be040e6a11b0977ef7ea8c8005f01a3c29 \ + --hash=sha256:4ef80829117a8061f974b2fda8ec799717242353bff55f8a29411794d635d964 \ + 
--hash=sha256:58837f9d221744d4c92d2cf7201c6acd19623b50c643b56992cbd2b745485d3d \ + --hash=sha256:5a8f91c64f390ecee09ff793319f30a0f32492e99f5dc1c72bc361f23ccd0a9a \ + --hash=sha256:5addc9d0209a9afca5fc070f93b726bf7003bd63a427f65ef797a931782e7edc \ + --hash=sha256:6239d4e2e0b52c8bd38c51b760cd870069f0bdf99700a62cd509d7a031749a55 \ + --hash=sha256:66e2fe786ef28da2b28e222c89502b2af984858091675044d93cb50e6f46d7af \ + --hash=sha256:69c0771ca5653c7d4b65203cbfc5e66db9375f1078689459fe196fe08b7b4930 \ + --hash=sha256:6ac965a998e1388e6ff2e9781f499ad1eaa41e962a40d11c7823c9952c77123e \ + --hash=sha256:6c56c3d47da04f921b73ff9415fbaa939f684d47293f071aa9cbb13c94afc17d \ + --hash=sha256:6f85739e80d13644b981a88f529d79c5bdf646b460ba190bffcaf6d57b2a9863 \ + --hash=sha256:706e7b739fdd17cb89e1fbf712d9dc21311fc2333f6d435eac2d4ee81985098c \ + --hash=sha256:741ba2f511cc9626b7561a440f87d658aabb3d6b744a86a3c025f866b4d19e7f \ + --hash=sha256:7434a61b158be563c1362d9071358f8ab91b8d928728cd2882af060481244c9e \ + --hash=sha256:76066d7ff61ba6bf3cb5efe2428fc82aac91802844c022d849a1f0f53820502d \ + --hash=sha256:7979b834ec7a33aafae34a90aad9f914c41fd6eaa8474e66953f3f6f7cbd4368 \ + --hash=sha256:7eece6fbd3eae4a92d7c748ae825cbc1ee41a89bb1c3db05b5578ed3cfcfd7cb \ + --hash=sha256:7ef1e014eed78ab650bef9a6a9cbe50b052c0aebe553fb2881e0453717573f52 \ + --hash=sha256:81dce2ddc9f6e8f543d94b05d56e70d03a0774d32f6cca53e978dc01e4fc75b8 \ + --hash=sha256:82fcc1f1cc3ff1ab8a57ba619b149b907072e750815c5ba63e7aa2e1163384a4 \ + --hash=sha256:8d1f21af4c1539051049796a0f50aa342f9a27cde57318f2fc41ed50b0dbc4ac \ + --hash=sha256:90a79bce019c442604662d17bf69df99090e24cdc6ad95b18b6725c2988a490e \ + --hash=sha256:9145f092b5d1977ec8c0ab46e7b3381b2fd069957b9862a43bd383e5c01d18c2 \ + --hash=sha256:91dc1d531f80c862441d7b66c4505cd6ea9d312f01fb2f4654f40c6fdf5cc37a \ + --hash=sha256:979c24cbefaf2420c4e377ecd1f165ea08cc3d1fbb44bdc51bccbbf7c66a2cb4 \ + --hash=sha256:994645a46c6a740ee8ce8df7911d4aee458d9b1bc5639bc968226763d07f00fa \ + 
--hash=sha256:9b98b7681a9437262947f41c7fac567c7e1f6eddd94b0483596d320092004533 \ + --hash=sha256:9c6b4d23c04831e3ab61717a707a5d763b300213db49ca680edf8bf13ab5d91b \ + --hash=sha256:9c6d0ced3c06d0f183b73d3c5920727268d2201aa0fe6d55c60d68c792ff3588 \ + --hash=sha256:9fd88f373cb71e6b59b7fa597e47e518282455c2734fd4306a05ca219a1991b0 \ + --hash=sha256:a8f4e49fc3ce020f65411432183e6775f24e02dff617281094ba6ab079ef0915 \ + --hash=sha256:a9e908ef5889cda4de038892b9accc36d33d72fb3e12c747e2799a0e806ec841 \ + --hash=sha256:ad08a69728ff3c79866d729b095872afe1e0557251da4abb2c5faff15a91d19a \ + --hash=sha256:adbccd17dcaff65704c856bd29951c58a1bd4b2b0f8ad6b826dbd543fe740988 \ + --hash=sha256:b0c7d2f698e83f15228ba41c135501cfe7d5740181d5903e250e47f617eb4292 \ + --hash=sha256:b3ab05a182c7937fb374f7e946f04fb23a0c0699c0450e9fb02ef567412d2fa3 \ + --hash=sha256:b6104f9a46bd8743e4f738afef69b153c4b8b592d35ae46db07fc28ae3d5fb7c \ + --hash=sha256:ba7cd6dc4d585ea544c1412019921570ebd8a597fabf475acc4528210d7c4a6f \ + --hash=sha256:bc72c231f5449d86d6c7d9cc7cd819b6eb30134bb770b8cfdc0765e48ef9c420 \ + --hash=sha256:bce8814b076f0ce5766dc87d5a056b0e9437b8e0cd351b9a6c4e1134a7dfbda9 \ + --hash=sha256:be5e22bbb67924dea15039c3282fa4cc6cdfbe0cbbd1c0515f9223186fc2ec5f \ + --hash=sha256:be6b7b8d42d3090b6c80793524fa66c57ad7ee3fe9722b258aec6d0672543fd0 \ + --hash=sha256:bfe50b61bab1b1ec260fa7cd91106fa9fece57e6beba05630afe27c71259c59b \ + --hash=sha256:bff507ae210371d4b1fe316d03433ac099f184d570a1a611e541923f78f05037 \ + --hash=sha256:c148bec483cc4b421562b4bcedb8e28a3b84fcc8f0aa4418e10898f3c2c0eb9b \ + --hash=sha256:c15ad0aee158a15e17e0495e1e18741573d04eb6da06d8b84af726cfc1ed02ee \ + --hash=sha256:c2169b2dcabf4e608416f7f9468737583ce5f0a6e8677c4efbf795ce81109d7c \ + --hash=sha256:c55853684fe08d4897c37dfc5faeff70607a5f1806c8be148f1695be4a63414b \ + --hash=sha256:c65a3b5330b54103e7d21cac3f6bf3900d46f6d50138d73343d9e5b2900b2353 \ + --hash=sha256:c7964c2183c3e6cce3f497e3a9f49d182e969f2dc3aeeadfa18945ff7bdd7051 \ + 
--hash=sha256:cc3f1c053b73f20c7ad88b0d1d23be7e7b3901229ce89f5000a8399746a6e039 \ + --hash=sha256:ce615c92d90df8373d9e13acddd154152645c0dc060871abf6bd43809673d20a \ + --hash=sha256:d29338556a59423d9ff7b6eb0cb89ead2b0875e08fe522f3e068b955c3e7b59b \ + --hash=sha256:d8a993c0a0ffd5f2d3bda23d0cd75e7086736f8f8268de8a82fbc4bd0ac6791e \ + --hash=sha256:d9c727bbcf0065cbb20f39d2b4f932f8fa1631c3e01fcedc979bd4f51fe051c5 \ + --hash=sha256:dac37cf08fcf2094159922edc7a2784cfcc5c70f8354469f79ed085f0328ebdf \ + --hash=sha256:dd829712de97753367153ed84f2de752b86cd1f7a88b55a3a775eb52eafe8a94 \ + --hash=sha256:e54ddd0bb8fb626aa1f9ba7b36629564544954fff9669b15da3610c22b9a0991 \ + --hash=sha256:e77c90ab5997e85901da85131fd36acd0ed2221368199b65f0d11bca44549711 \ + --hash=sha256:ebedc192abbc7fd13c5ee800e83a6df252bec691eb2c4bedc9f8b2e2903f5e2a \ + --hash=sha256:ef71561f82a89af6cfcbee47f0fabfdb6e63788a9258e913955d89fdd96902ab \ + --hash=sha256:f0a47efb1dbef13af9c9a54a94a0b814902e547b7f21acb29434504d18f36e3a \ + --hash=sha256:f4f2ca6df64cbdd27f27b34f35adb640b5d2d77264228554e68deda54456eb11 \ + --hash=sha256:fb02e4257376ae25c6dd95a5aec377f9b18c09be6ebdefa7ad209b9137b73d48 + # via + # nltk + # sacrebleu + # transformers +requests[socks]==2.31.0 \ + --hash=sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f \ + --hash=sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1 + # via + # gdown + # huggingface-hub + # requests-oauthlib + # tensorboard + # tensorflow-datasets + # tfds-nightly + # transformers +requests-oauthlib==1.3.1 \ + --hash=sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5 \ + --hash=sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a + # via google-auth-oauthlib +rich==11.2.0 \ + --hash=sha256:1a6266a5738115017bb64a66c59c717e7aa047b3ae49a011ede4abdeffc6536e \ + --hash=sha256:d5f49ad91fb343efcae45a2b2df04a9755e863e50413623ab8c9e74f05aee52b + # via flax +rouge-score==0.1.2 \ + 
--hash=sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04 + # via t5 +rsa==4.9 \ + --hash=sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7 \ + --hash=sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21 + # via google-auth +sacrebleu==2.3.1 \ + --hash=sha256:352227b8ca9e04ed509266d1fee6c8cff0ea1417c429f8c684645ad2db8b02e7 \ + --hash=sha256:7969b294f15dae84d80fb2b76d30c83b245f49f4ecb1cac79acb553eb93cb537 + # via t5 +safetensors==0.4.0 \ + --hash=sha256:002301c1afa32909f83745b0c124d002e7ae07e15671f3b43cbebd0ffc5e6037 \ + --hash=sha256:003077ec85261d00061058fa12e3c1d2055366b02ce8f2938929359ffbaff2b8 \ + --hash=sha256:00a9b157be660fb7ba88fa2eedd05ec93793a5b61e43e783e10cb0b995372802 \ + --hash=sha256:0219cea445177f6ad1f9acd3a8d025440c8ff436d70a4a7c7ba9c36066aa9474 \ + --hash=sha256:0620ab0d41e390ccb1c4ea8f63dc00cb5f0b96a5cdd3cd0d64c21765720c074a \ + --hash=sha256:0f45230f20a206e5e4c7f7bbf9342178410c6f8b0af889843aa99045a76f7691 \ + --hash=sha256:10b65cd3ad79f5d0daf281523b4146bc271a34bb7430d4e03212e0de8622dab8 \ + --hash=sha256:114decacc475a6a9e2f9102a00c171d113ddb5d35cb0bda0db2c0c82b2eaa9ce \ + --hash=sha256:1db87155454c168aef118d5657a403aee48a4cb08d8851a981157f07351ea317 \ + --hash=sha256:1e2f9c69b41d03b4826ffb96b29e07444bb6b34a78a7bafd0b88d59e8ec75b8a \ + --hash=sha256:2077801800b4b13301d8d6290c7fb5bd60737320001717153ebc4371776643b5 \ + --hash=sha256:2289ae6dbe6d027ecee016b28ced13a2e21a0b3a3a757a23033a2d1c0b1bad55 \ + --hash=sha256:232029f0a9fa6fa1f737324eda98a700409811186888536a2333cbbf64e41741 \ + --hash=sha256:25cd407955bad5340ba17f9f8ac789a0d751601a311e2f7b2733f9384478c95e \ + --hash=sha256:27a24ca8822c469ee452db4c13418ba983315a0d863c018a9af15f2305eac38c \ + --hash=sha256:2b9b94133ed2ae9dda0e95dcace7b7556eba023ffa4c4ae6df8f99377f571d6a \ + --hash=sha256:2f27b8ef814c5fb43456caeb7f3cbb889b76115180aad1f42402839c14a47c5b \ + --hash=sha256:2f99d90c91b7c76b40a862acd9085bc77f7974a27dee7cfcebe46149af5a99a1 \ + 
--hash=sha256:38032078ed9fea52d06584e441bccc73fb475c4581600c6d6166de2fe2deb3d1 \ + --hash=sha256:3910fb5bf747413b59f1a34e6d2a993b589fa7d919709518823c70efaaa350bd \ + --hash=sha256:3b6c1316ffde6cb4bf22c7445bc9fd224b4d1b9dd7320695f5611c89e802e4b6 \ + --hash=sha256:3bfed574f6b1e7e7fe1f17213278875ef6c6e8b1582ab6eda93947db1178cae6 \ + --hash=sha256:3ebf6bcece5d5d1bd6416472f94604d2c834ca752ac60ed42dba7157e595a990 \ + --hash=sha256:3f4d90c79a65ba2fe2ff0876f6140748f0a3ce6a21e27a35190f4f96321803f8 \ + --hash=sha256:3f667a4c12fb593f5f66ce966cb1b14a7148898b2b1a7f79e0761040ae1e3c51 \ + --hash=sha256:40d7cf03493bfe75ef62e2c716314474b28d9ba5bf4909763e4b8dd14330c01a \ + --hash=sha256:435fd136a42492b280cb55126f9ce9535b35dd49df2c5d572a5945455a439448 \ + --hash=sha256:44f84373e42183bd56a13a1f2d8acb1db7fedaeffbd83e79cec861477eee1af4 \ + --hash=sha256:469360b9451db10bfed3881378d5a71b347ecb1ab4f42367d77b8164a13af70b \ + --hash=sha256:48b92059b1a4ad163024d4f526e0e73ebe2bb3ae70537e15e347820b4de5dc27 \ + --hash=sha256:491b3477e4d0d4599bb75d79da4b75af2e6ed9b1f6ec2b715991f0bc927bf09a \ + --hash=sha256:4936096a57c62e84e200f92620a536be067fc5effe46ecc7f230ebb496ecd579 \ + --hash=sha256:495dcaea8fbab70b927d2274e2547824462737acbf98ccd851a71124f779a5c6 \ + --hash=sha256:4b2aa57b5a4d576f3d1dd6e56980026340f156f8a13c13016bfac4e25295b53f \ + --hash=sha256:4d512525a8e05a045ce6698066ba0c5378c174a83e0b3720a8c7799dc1bb06f3 \ + --hash=sha256:4fe9e3737b30de458225a23926219ca30b902ee779b6a3df96eaab2b6d625ec2 \ + --hash=sha256:59d2e10b7e0cd18bb73ed7c17c624a5957b003b81345e18159591771c26ee428 \ + --hash=sha256:5daa05058f7dce85b5f9f60c4eab483ed7859d63978f08a76e52e78859ff20ca \ + --hash=sha256:5f9909512bcb6f712bdd04c296cdfb0d8ff73d258ffc5af884bb62ea02d221e0 \ + --hash=sha256:61a00f281391fae5ce91df70918bb61c12d2d514a493fd8056e12114be729911 \ + --hash=sha256:6686ce01b8602d55a7d9903c90d4a6e6f90aeb6ddced7cf4605892d0ba94bcb8 \ + --hash=sha256:67762d36ae088c73d4a3c96bfc4ea8d31233554f35b6cace3a18533238d462ea \ + 
--hash=sha256:67ab171eeaad6972d3971c53d29d53353c67f6743284c6d637b59fa3e54c8a94 \ + --hash=sha256:6b563a14c43614815a6b524d2e4edeaace50b717f7e7487bb227dd5b68350f5a \ + --hash=sha256:6c42623ae7045615d9eaa6877b9df1db4e9cc71ecc14bcc721ea1e475dddd595 \ + --hash=sha256:6c5556c2ec75f5a6134866eddd7341cb36062e6edaea343478a279591b63ddba \ + --hash=sha256:72ddb741dd5fe42521db76a70e012f76995516a12e7e0ef26be03ea9be77802a \ + --hash=sha256:73e7696dcf3f72f99545eb1abe6106ad65ff1f62381d6ce4b34be3272552897a \ + --hash=sha256:74e2a448ffe19be188b457b130168190ee73b5a75e45ba96796320c1f5ae35d2 \ + --hash=sha256:79a983b09782dacf9a1adb19bb98f4a8f6c3144108939f572c047b5797e43cf5 \ + --hash=sha256:79dd46fb1f19282fd12f544471efb97823ede927cedbf9cf35550d92b349fdd2 \ + --hash=sha256:7a524382b5c55b5fbb168e0e9d3f502450c8cf3fb81b93e880018437c206a482 \ + --hash=sha256:7abe0e157a49a75aeeccfbc4f3dac38d8f98512d3cdb35c200f8e628dc5773cf \ + --hash=sha256:7b2d6101eccc43c7be0cb052f13ceda64288b3d8b344b988ed08d7133cbce2f3 \ + --hash=sha256:7ffc736039f08a9ca1f09816a7481b8e4469c06e8f8a5ffa8cb67ddd79e6d77f \ + --hash=sha256:806379f37e1abd5d302288c4b2f4186dd7ea7143d4c7811f90a8077f0ae8967b \ + --hash=sha256:80cb8342f00f3c41b3b93b1a599b84723280d3ac90829bc62262efc03ab28793 \ + --hash=sha256:82e8fc4e3503cd738fd40718a430fe0e5ce6e7ff91a73d6ce628bbb89c41e8ce \ + --hash=sha256:87b328ee1591adac332543e1f5fc2c2d7f149b745ebb0d58d7850818ff9cee27 \ + --hash=sha256:8a6abfe67692f81b8bdb99c837f28351c17e624ebf136970c850ee989c720446 \ + --hash=sha256:8e735b0f79090f6855b55e205e820b7b595502ffca0009a5c13eef3661ce465b \ + --hash=sha256:8f2ca939bbd8fb2f4dfa28e39a146dad03bc9325e9fc831b68f7b98f69a5a2f1 \ + --hash=sha256:964ef166a286ce3b023d0d0bd0e21d440a1c8028981c8abdb136bc7872ba9b3d \ + --hash=sha256:9849ea60c7e840bfdd6030ad454d4a6ba837b3398c902f15a30460dd6961c28c \ + --hash=sha256:9b8fd6cc2f3bda444a048b541c843c7b7fefc89c4120d7898ea7d5b026e93891 \ + --hash=sha256:9e583fa68e5a07cc859c4e13c1ebff12029904aa2e27185cf04a1f57fe9a81c4 \ + 
--hash=sha256:9ed55f4a20c78ff3e8477efb63c8303c2152cdfb3bfea4d025a80f54d38fd628 \ + --hash=sha256:a54c21654a47669b38e359e8f852af754b786c9da884bb61ad5e9af12bd71ccb \ + --hash=sha256:a738970a367f39249e2abb900d9441a8a86d7ff50083e5eaa6e7760a9f216014 \ + --hash=sha256:a78ffc0795d3595cd9e4d453502e35f764276c49e434b25556a15a337db4dafc \ + --hash=sha256:a86565a5c112dd855909e20144947b4f53abb78c4de207f36ca71ee63ba5b90d \ + --hash=sha256:acf0180283c2efae72f1d8c0a4a7974662091df01be3aa43b5237b1e52ed0a01 \ + --hash=sha256:b561fbc044db7beff2ece0ec219a291809d45a38d30c6b38e7cc46482582f4ba \ + --hash=sha256:b69554c143336256260eceff1d3c0969172a641b54d4668489a711b05f92a2c0 \ + --hash=sha256:b6b60a58a8f7cc7aed3b5b73dce1f5259a53c83d9ba43a76a874e6ad868c1b4d \ + --hash=sha256:b985953c3cf11e942eac4317ef3db3da713e274109cf7cfb6076d877054f013e \ + --hash=sha256:bc1fa8d067733cb67f22926689ee808f08afacf7700d2ffb44efae90a0693eb1 \ + --hash=sha256:bd63d83a92f1437a8b0431779320376030ae43ace980bea5686d515de0784100 \ + --hash=sha256:bf6458959f310f551cbbeef2255527ade5f783f952738e73e4d0136198cc3bfe \ + --hash=sha256:c42bdea183dbaa99e2f0e6120dc524df79cf4289a6f90f30a534444ef20f49fa \ + --hash=sha256:c4a0a47c8640167792d8261ee21b26430bbc39130a7edaad7f4c0bc05669d00e \ + --hash=sha256:c68132727dd86fb641102e494d445f705efe402f4d5e24b278183a15499ab400 \ + --hash=sha256:c8f194f45ab6aa767993c24f0aeb950af169dbc5d611b94c9021a1d13b8a1a34 \ + --hash=sha256:cbc4a4da01143472323c145f3c289e5f6fabde0ac0a3414dabf912a21692fff4 \ + --hash=sha256:cd02b495ba0814619f40bda46771bb06dbbf1d42524b66fa03b2a736c77e4515 \ + --hash=sha256:cef7bb5d9feae7146c3c3c7b3aef7d2c8b39ba7f5ff4252d368eb69462a47076 \ + --hash=sha256:cf8fdca709b2470a35a59b1e6dffea75cbe1214b22612b5dd4c93947697aea8b \ + --hash=sha256:d33d29e846821f0e4f92614022949b09ccf063cb36fe2f9fe099cde1efbfbb87 \ + --hash=sha256:d8c4f5ed4ede384dea8c99bae76b0718a828dbf7b2c8ced1f44e3b9b1a124475 \ + --hash=sha256:db7bb48ca9e90bb9526c71b388d38d8de160c0354f4c5126df23e8701a870dcb \ + 
--hash=sha256:dcaa40bc363edda145db75cd030f3b1822e5478d550c3500a42502ecef32c959 \ + --hash=sha256:e7916e814a90008de767b1c164a1d83803693c661ffe9af5a697b22e2752edb0 \ + --hash=sha256:e853e189ba7d47eaf561094586692ba2bbdd258c096f1755805cac098de0e6ab \ + --hash=sha256:ed50f239b0ce7ae85b078395593b4a351ede7e6f73af25f4873e3392336f64c9 \ + --hash=sha256:f0daa788273d683258fb1e4a5e16bef4486b2fca536451a2591bc0f4a6488895 \ + --hash=sha256:f5f75fa97ccf32a3c7af476c6a0e851023197d3c078f6de3612008fff94735f9 \ + --hash=sha256:f8d2416734e850d5392afffbcb2b8985ea29fb171f1cb197e2ae51b8e35d6438 \ + --hash=sha256:fdc34027b545a69be3d4220c140b276129523e4e46db06ad1a0b60d6a4cf9214 + # via transformers +scikit-learn==1.3.1 \ + --hash=sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043 \ + --hash=sha256:0ce9233cdf0cdcf0858a5849d306490bf6de71fa7603a3835124e386e62f2311 \ + --hash=sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122 \ + --hash=sha256:14e8775eba072ab10866a7e0596bc9906873e22c4c370a651223372eb62de180 \ + --hash=sha256:1a231cced3ee3fa04756b4a7ab532dc9417acd581a330adff5f2c01ac2831fcf \ + --hash=sha256:1ec668ce003a5b3d12d020d2cde0abd64b262ac5f098b5c84cf9657deb9996a8 \ + --hash=sha256:3153612ff8d36fa4e35ef8b897167119213698ea78f3fd130b4068e6f8d2da5a \ + --hash=sha256:4d379f2b34096105a96bd857b88601dffe7389bd55750f6f29aaa37bc6272eb5 \ + --hash=sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca \ + --hash=sha256:58b0c2490eff8355dc26e884487bf8edaccf2ba48d09b194fb2f3a026dd64f9d \ + --hash=sha256:66f7bb1fec37d65f4ef85953e1df5d3c98a0f0141d394dcdaead5a6de9170347 \ + --hash=sha256:6bb9490fdb8e7e00f1354621689187bef3cab289c9b869688f805bf724434755 \ + --hash=sha256:7d8dee8c1f40eeba49a85fe378bdf70a07bb64aba1a08fda1e0f48d27edfc3e6 \ + --hash=sha256:8454d57a22d856f1fbf3091bd86f9ebd4bff89088819886dc0c72f47a6c30652 \ + --hash=sha256:845f81c7ceb4ea6bac64ab1c9f2ce8bef0a84d0f21f3bece2126adcc213dfecd \ + 
--hash=sha256:8d993fb70a1d78c9798b8f2f28705bfbfcd546b661f9e2e67aa85f81052b9c53 \ + --hash=sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac \ + --hash=sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236 \ + --hash=sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26 \ + --hash=sha256:a7135a03af71138669f19bc96e7d0cc8081aed4b3565cc3b131135d65fc642ba \ + --hash=sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028 \ + --hash=sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be \ + --hash=sha256:ccbbedae99325628c1d1cbe3916b7ef58a1ce949672d8d39c8b190e10219fd32 \ + --hash=sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4 \ + --hash=sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f \ + --hash=sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217 + # via t5 +scipy==1.10.0 \ + --hash=sha256:0490dc499fe23e4be35b8b6dd1e60a4a34f0c4adb30ac671e6332446b3cbbb5a \ + --hash=sha256:0ab2a58064836632e2cec31ca197d3695c86b066bc4818052b3f5381bfd2a728 \ + --hash=sha256:151f066fe7d6653c3ffefd489497b8fa66d7316e3e0d0c0f7ff6acca1b802809 \ + --hash=sha256:16ba05d3d1b9f2141004f3f36888e05894a525960b07f4c2bfc0456b955a00be \ + --hash=sha256:27e548276b5a88b51212b61f6dda49a24acf5d770dff940bd372b3f7ced8c6c2 \ + --hash=sha256:2ad449db4e0820e4b42baccefc98ec772ad7818dcbc9e28b85aa05a536b0f1a2 \ + --hash=sha256:2f9ea0a37aca111a407cb98aa4e8dfde6e5d9333bae06dfa5d938d14c80bb5c3 \ + --hash=sha256:38bfbd18dcc69eeb589811e77fae552fa923067fdfbb2e171c9eac749885f210 \ + --hash=sha256:3afcbddb4488ac950ce1147e7580178b333a29cd43524c689b2e3543a080a2c8 \ + --hash=sha256:42ab8b9e7dc1ebe248e55f54eea5307b6ab15011a7883367af48dd781d1312e4 \ + --hash=sha256:441cab2166607c82e6d7a8683779cb89ba0f475b983c7e4ab88f3668e268c143 \ + --hash=sha256:4bd0e3278126bc882d10414436e58fa3f1eca0aa88b534fcbf80ed47e854f46c \ + 
--hash=sha256:4df25a28bd22c990b22129d3c637fd5c3be4b7c94f975dca909d8bab3309b694 \ + --hash=sha256:5cd7a30970c29d9768a7164f564d1fbf2842bfc77b7d114a99bc32703ce0bf48 \ + --hash=sha256:6e4497e5142f325a5423ff5fda2fff5b5d953da028637ff7c704378c8c284ea7 \ + --hash=sha256:6faf86ef7717891195ae0537e48da7524d30bc3b828b30c9b115d04ea42f076f \ + --hash=sha256:954ff69d2d1bf666b794c1d7216e0a746c9d9289096a64ab3355a17c7c59db54 \ + --hash=sha256:9b878c671655864af59c108c20e4da1e796154bd78c0ed6bb02bc41c84625686 \ + --hash=sha256:b901b423c91281a974f6cd1c36f5c6c523e665b5a6d5e80fcb2334e14670eefd \ + --hash=sha256:c8b3cbc636a87a89b770c6afc999baa6bcbb01691b5ccbbc1b1791c7c0a07540 \ + --hash=sha256:e096b062d2efdea57f972d232358cb068413dc54eec4f24158bcbb5cb8bddfd8 + # via + # -r requirements.in + # jax + # jaxlib + # mesh-tensorflow + # scikit-learn + # t5 +sentencepiece==0.1.99 \ + --hash=sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec \ + --hash=sha256:019e7535108e309dae2b253a75834fc3128240aa87c00eb80732078cdc182588 \ + --hash=sha256:0b0f55d0a0ee1719b4b04221fe0c9f0c3461dc3dabd77a035fa2f4788eb3ef9a \ + --hash=sha256:0eaf3591dd0690a87f44f4df129cf8d05d8a4029b5b6709b489b8e27f9a9bcff \ + --hash=sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73 \ + --hash=sha256:14b0eccb7b641d4591c3e12ae44cab537d68352e4d3b6424944f0c447d2348d5 \ + --hash=sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f \ + --hash=sha256:18e800f206cd235dc27dc749299e05853a4e4332e8d3dfd81bf13d0e5b9007d9 \ + --hash=sha256:27b866b5bd3ddd54166bbcbf5c8d7dd2e0b397fac8537991c7f544220b1f67bc \ + --hash=sha256:2ae1c40cda8f9d5b0423cfa98542735c0235e7597d79caf318855cdf971b2280 \ + --hash=sha256:2d95e19168875b70df62916eb55428a0cbcb834ac51d5a7e664eda74def9e1e0 \ + --hash=sha256:33e6f690a1caebb4867a2e367afa1918ad35be257ecdb3455d2bbd787936f155 \ + --hash=sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd \ + 
--hash=sha256:38efeda9bbfb55052d482a009c6a37e52f42ebffcea9d3a98a61de7aee356a28 \ + --hash=sha256:445b0ec381af1cd4eef95243e7180c63d9c384443c16c4c47a28196bd1cda937 \ + --hash=sha256:47c378146928690d1bc106fdf0da768cebd03b65dd8405aa3dd88f9c81e35dba \ + --hash=sha256:57efcc2d51caff20d9573567d9fd3f854d9efe613ed58a439c78c9f93101384a \ + --hash=sha256:62e24c81e74bd87a6e0d63c51beb6527e4c0add67e1a17bac18bcd2076afcfeb \ + --hash=sha256:6a904c46197993bd1e95b93a6e373dca2f170379d64441041e2e628ad4afb16f \ + --hash=sha256:6c030b081dc1e1bcc9fadc314b19b740715d3d566ad73a482da20d7d46fd444c \ + --hash=sha256:6d3c56f24183a1e8bd61043ff2c58dfecdc68a5dd8955dc13bab83afd5f76b81 \ + --hash=sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7 \ + --hash=sha256:84dbe53e02e4f8a2e45d2ac3e430d5c83182142658e25edd76539b7648928727 \ + --hash=sha256:85b476406da69c70586f0bb682fcca4c9b40e5059814f2db92303ea4585c650c \ + --hash=sha256:8a1abff4d1ff81c77cac3cc6fefa34fa4b8b371e5ee51cb7e8d1ebc996d05983 \ + --hash=sha256:8a321866c2f85da7beac74a824b4ad6ddc2a4c9bccd9382529506d48f744a12c \ + --hash=sha256:9832f08bb372d4c8b567612f8eab9e36e268dff645f1c28f9f8e851be705f6d1 \ + --hash=sha256:9ba142e7a90dd6d823c44f9870abdad45e6c63958eb60fe44cca6828d3b69da2 \ + --hash=sha256:a2a0260cd1fb7bd8b4d4f39dc2444a8d5fd4e0a0c4d5c899810ef1abf99b2d45 \ + --hash=sha256:b133e8a499eac49c581c3c76e9bdd08c338cc1939e441fee6f92c0ccb5f1f8be \ + --hash=sha256:b7b1a9ae4d7c6f1f867e63370cca25cc17b6f4886729595b885ee07a58d3cec3 \ + --hash=sha256:baed1a26464998f9710d20e52607c29ffd4293e7c71c6a1f83f51ad0911ec12c \ + --hash=sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0 \ + --hash=sha256:c42f753bcfb7661c122a15b20be7f684b61fc8592c89c870adf52382ea72262d \ + --hash=sha256:c6890ea0f2b4703f62d0bf27932e35808b1f679bdb05c7eeb3812b935ba02001 \ + --hash=sha256:c84ce33af12ca222d14a1cdd37bd76a69401e32bc68fe61c67ef6b59402f4ab8 \ + --hash=sha256:c8843d23a0f686d85e569bd6dcd0dd0e0cbc03731e63497ca6d5bacd18df8b85 \ + 
--hash=sha256:cfbcfe13c69d3f87b7fcd5da168df7290a6d006329be71f90ba4f56bc77f8561 \ + --hash=sha256:d0f644c9d4d35c096a538507b2163e6191512460035bf51358794a78515b74f7 \ + --hash=sha256:d89adf59854741c0d465f0e1525b388c0d174f611cc04af54153c5c4f36088c4 \ + --hash=sha256:db361e03342c41680afae5807590bc88aa0e17cfd1a42696a160e4005fcda03b \ + --hash=sha256:ed6ea1819fd612c989999e44a51bf556d0ef6abfb553080b9be3d347e18bcfb7 \ + --hash=sha256:f90d73a6f81248a909f55d8e6ef56fec32d559e1e9af045f0b0322637cb8e5c7 \ + --hash=sha256:fa16a830416bb823fa2a52cbdd474d1f7f3bba527fd2304fb4b140dad31bb9bc \ + --hash=sha256:fb71af492b0eefbf9f2501bec97bcd043b6812ab000d119eaf4bd33f9e283d03 + # via + # -r requirements.in + # seqio + # seqio-nightly + # t5 +seqio==0.0.18 \ + --hash=sha256:2b2c4a50e507e2cbe2fc171005e2ff35ad50ee7899b2929a3d532b68b4c38424 \ + --hash=sha256:856fec6be4f2f5f5ffafb10d36b41f6a5b90670ee964dd600a341c6319349c7b + # via -r requirements.in +seqio-nightly==0.0.17.dev20231013 \ + --hash=sha256:60fcc81f439a486a12e46cbce193e88af1440d19ffeb19eb64eac8bc8e99488f \ + --hash=sha256:9474513d7f6ebf2d62b762bcfd82822dbb7f40b6c39e9e07bee292e85ffa6d3a + # via t5 +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via + # astunparse + # gdown + # google-pasta + # mesh-tensorflow + # ml-collections + # promise + # python-dateutil + # rouge-score + # t5 + # tensorflow +soupsieve==2.5 \ + --hash=sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690 \ + --hash=sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7 + # via beautifulsoup4 +t5==0.9.4 \ + --hash=sha256:53e40efdfc4d8f614cd8d38301e39d5c146f9d5fd66b6823d9c181b60370b2b5 + # via -r requirements.in +tabulate==0.9.0 \ + --hash=sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c \ + --hash=sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f + # 
via sacrebleu +tensorboard==2.13.0 \ + --hash=sha256:ab69961ebddbddc83f5fa2ff9233572bdad5b883778c35e4fe94bf1798bd8481 + # via tensorflow +tensorboard-data-server==0.7.1 \ + --hash=sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594 \ + --hash=sha256:9938bd39f5041797b33921066fba0eab03a0dd10d1887a05e62ae58841ad4c3f \ + --hash=sha256:be8d016a1aa394e6198280d4a3dc37898f56467310c5f5e617cac10a783e055a + # via tensorboard +tensorflow==2.13.0 \ + --hash=sha256:00060c5516a61e30c51936084ebc37091d116efe9ae74b2818cbd8b2006218e7 \ + --hash=sha256:06559eeaa69e6561cccbe2d02b015bcec663e875c8bbc4643f55692476e52147 \ + --hash=sha256:076d953a1508dc58bf95f30f58bcc9ee364b1353c61e143cb20c2dada91afb05 \ + --hash=sha256:11ad6a7ff49b4a690fbf37a5adaf28ba2686350a859c5f13c58dc8d2cc670375 \ + --hash=sha256:19ee67901702b26787ad685cca65730c163c101c0c2f238a2584d714e0fa8c25 \ + --hash=sha256:2822ac48c38d69b7fb104e606dacbd763c4bf5d3b20791f25be16a5076d01912 \ + --hash=sha256:5e0fdadec59de3d11c5b5129ddc38e739bde7aa13095b82e19d4380e14d04999 \ + --hash=sha256:6fff426661d286a4c634da44275d2ea2b951f392f3e65c8603681e7cb040586a \ + --hash=sha256:72d68b8c2f382e2d01b956c8ba516c0a7d5dad98111dd351bf82bfa646aa1c72 \ + --hash=sha256:7a08c0e2938ed5b642a8787678123827477b81d316055d5073fff82fa183eb82 \ + --hash=sha256:89125443e998548059c4e4a129dfab2b1ff7f2fd4c8eaed97842c3cd9b663101 \ + --hash=sha256:948003b5a23b72b3d89746d729e62ec5f01e47460f05521b2211d95069f569ba \ + --hash=sha256:9c04bc3023b6c4cfb9ee9759c3f03f21993891b4c345df52eb5519204fbf28c0 \ + --hash=sha256:b2978b39e8b3919059b5fd9e28508d50a77965d06ed0b537ed71c97de22dabdf \ + --hash=sha256:cbb83561bb7d55859eaefc70c674e58713d4e10c10927423ed836a5289bbfa86 \ + --hash=sha256:de77306c0c22c9d8754f54700752ac3a1efee895c5357308e6594436404bfbc0 \ + --hash=sha256:e0cf94d36ceaba8f158c6e15404a81fd5b3aa4cb04147c674cf55bd1aec78154 \ + --hash=sha256:e8f0b69ee2f800399fc6bc7ec55fecfa33662d136e425485959d90638f32a32a \ + 
--hash=sha256:fa7abe265cc3ebccc9b405a280bf674824c6d85df5e6ccfa985987b3c9d265b4 \ + --hash=sha256:fb2ff1129c93e853c19897d6a22ed0ec56387f5c6290ec03dec1c6f7b80bc396 + # via + # -r requirements.in + # clu + # tensorflow-text +tensorflow-datasets==4.9.3 \ + --hash=sha256:09cd60eccab0d5a9d15f53e76ee0f1b530ee5aa3665e42be621a4810d9fa5db6 \ + --hash=sha256:90390077dde2c9e4e240754ddfc5bb50b482946d421c8a34677c3afdb0463427 + # via + # -r requirements.in + # clu + # mesh-tensorflow +tensorflow-estimator==2.13.0 \ + --hash=sha256:6f868284eaa654ae3aa7cacdbef2175d0909df9fcf11374f5166f8bf475952aa + # via tensorflow +tensorflow-hub==0.15.0 \ + --hash=sha256:8af12cb2d1fc0d1a9509a620e7589daf173714e99f08aaf090a4748ff20b45c8 + # via tensorflow-text +tensorflow-io-gcs-filesystem==0.34.0 \ + --hash=sha256:027a07553367187f918a99661f63ae0506b91b77a70bee9c7ccaf3920bf7cfe7 \ + --hash=sha256:0dafed144673e1173528768fe208a7c5a6e8edae40208381cac420ee7c918ec9 \ + --hash=sha256:182b0fbde7e9a537fda0b354c28b0b6c035736728de8fe2db7ef49cf90352014 \ + --hash=sha256:2b035f4c92639657b6d376929d550ac3dee9e6c0523eb434eefe0a27bae3d05b \ + --hash=sha256:396bfff61b49f80b86ddebe0c76ae0f2731689cee49ad7d782625180b50b13af \ + --hash=sha256:3f346b287ed2400e09b13cfd8524222fd70a66aadb9164c645286c2087007e9f \ + --hash=sha256:44ad387a812a78e7424bb8bee3820521ae1c044bddf72b1e163e8df95c124a74 \ + --hash=sha256:5813c336b4f7cb0a01ff4cc6cbd3edf11ef67305baf0e3cf634911b702f493f8 \ + --hash=sha256:6e6353123a5b51397950138a118876af833a7db66b531123bb86f82e80ab0e72 \ + --hash=sha256:7f60183473f0ca966451bb1d1bb5dc29b3cf9c74d1d0e7f2ed46760ed56bd4af \ + --hash=sha256:8d8664bddbe4e7b56ce94db8b93ea9077a158fb5e15364e11e29f93015ceea24 \ + --hash=sha256:a17a616d2c7fae83de4424404815843507d40d4eb0d507c636a5493a20c3d958 \ + --hash=sha256:b20622f8572fcb6c93e8f7d626327472f263e47ebd63d2153ef09162ef5ef7b5 \ + --hash=sha256:b9a93fcb01db269bc845a1ced431f3c61201755ce5f9ec4885760f30122276ef \ + 
--hash=sha256:cbe26c4a3332589c7b724f147df453b5c226993aa8d346a15536358d77b364c4 \ + --hash=sha256:d3feba2dd76f7c188137c34642d68d378f0eed81636cb95090ecb1496722707c \ + --hash=sha256:d831702fbb270996b27cda7fde06e0825b2ea81fd8dd3ead35242f4f8b3889b8 \ + --hash=sha256:ec4604c99cbb5b708f4516dee27aa655abae222b876c98b740f4c2f89dd5c001 \ + --hash=sha256:f211d2b3db8f9931765992b607b71cbfb98c8cd6169079d004a67a94ab10ecb4 + # via tensorflow +tensorflow-metadata==1.14.0 \ + --hash=sha256:5ff79bf96f98c800fc08270b852663afe7e74d7e1f92b50ba1487bfc63894cdb + # via + # tensorflow-datasets + # tfds-nightly +tensorflow-text==2.13.0 \ + --hash=sha256:142b35fc7f633250db2c4810e0e60eadc015292c7dde3fff6189213056f8fd7d \ + --hash=sha256:8af4379cfe8f454d3e8ad38627153e20d852a5ba10591b47c9dbb64490ae4f16 \ + --hash=sha256:94589df89c531f4c2c61029203c45d1f299f3204b47c3a1aa8ff636e7f58dadf \ + --hash=sha256:b98316df6dd576e62a7a56bd093d1488446d84831bacd1cbb46eb339585fd381 \ + --hash=sha256:d764c90ecceb9b603170a5c9b448b1cd369709c24d230913be83156ac4e3a431 \ + --hash=sha256:ef7694623a79793a1db0ac66ab834596ea40f4fe5c7fdd92e402537c9e496bf7 + # via + # seqio + # seqio-nightly +tensorstore==0.1.45 \ + --hash=sha256:05196a0464ce51867f1edd96e992fe01281de283b034d434ca6e81db319368c0 \ + --hash=sha256:0ce1a3d2bdbdb2c1102100ee23fa99a95b0bcdee9773862622d7da833516c8c9 \ + --hash=sha256:2ff6e5177ba2702f348bef3edc37619aa7646e43f33d1a567ba267db455699e4 \ + --hash=sha256:38468c621b2edf09cfdd2df4905890e83f1805c7645ec13e16df5eafabf0e5e5 \ + --hash=sha256:405bf40271eed5632a566cdb935beba87d9896d2f80caf75386febb529ddba45 \ + --hash=sha256:4346ab7afa0963dcaa8e64388a2bedab741c790786b577326a0b174d226c9320 \ + --hash=sha256:4915aee8355ee7dbc6f534d77a28c18001e19696f44f78760ec42845ac51edee \ + --hash=sha256:537805adb06fff2ce9a259b81920af4c34a20f752fa28205e722b7e58a60c790 \ + --hash=sha256:6d7b6cccb96b36356d3e61c4e89972b82123d799cc2ca50f743e30ce45d70739 \ + 
--hash=sha256:73df4ddafe4da8e0f919ed5a75f48839013da3a99128a719fe730855252051a6 \ + --hash=sha256:8659688ec9d89cdd71046c35b3c84cf92cd8c88251e6068f8a99d6991a965028 \ + --hash=sha256:871a1fde0712a153ac44774ddace3ad841609ff5be792734d44cffb520258e92 \ + --hash=sha256:9bc7cde6318363eb9d35fc6cacb6fcd5d7a03b0ee57bdd69249108c0164692d8 \ + --hash=sha256:a8960f0e546ee493ed67b77998859f0cb94772ea31e865bf76b0c79976ac9204 \ + --hash=sha256:c034fec18b6e3174d26df1cdd91ec67b720fc5de7ef0cc3804017dad8c211622 \ + --hash=sha256:ca212d127fcc4debb9f6b4274d584fe7724b2a349ca9444258a4127878dc3033 \ + --hash=sha256:f38bba6fc0668a950b76752c743b66851c4fc7360857e8b37a4f7a4e9786760b + # via flax +termcolor==2.3.0 \ + --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ + --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a + # via + # tensorflow + # tensorflow-datasets + # tfds-nightly +tfds-nightly==4.9.2.dev202308090034 \ + --hash=sha256:16cb572ec9b602b4202d2dc30c741f8cd1c63d8bb82e7d6024eea3b7e28e7d1c \ + --hash=sha256:a6ad764739f33b04b1bd4104a7373149c10455ce9378cbef6cb830339db2861c + # via + # seqio + # seqio-nightly + # t5 +threadpoolctl==3.2.0 \ + --hash=sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032 \ + --hash=sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355 + # via scikit-learn +tokenizers==0.14.1 \ + --hash=sha256:00df4c5bf25c153b432b98689609b426ae701a44f3d8074dcb619f410bc2a870 \ + --hash=sha256:01d2bd5935642de22a6c6778bb2307f9949cd6eaeeb5c77f9b98f0060b69f0db \ + --hash=sha256:040ee44efc1806900de72b13c1c3036154077d9cde189c9a7e7a50bbbdcbf39f \ + --hash=sha256:04ec1134a18ede355a05641cdc7700f17280e01f69f2f315769f02f7e295cf1e \ + --hash=sha256:08e55920b453c30b46d58accc68a38e8e7488d0c03babfdb29c55d3f39dd2052 \ + --hash=sha256:0c8ee283b249c3c3c201c41bc23adc3be2514ae4121eacdb5c5250a461eaa8c6 \ + --hash=sha256:0ce2f0ff2e5f12ac5bebaa690606395725239265d7ffa35f35c243a379316297 \ + 
--hash=sha256:102f118fa9b720b93c3217c1e239ed7bc1ae1e8dbfe9b4983a4f2d7b4ce6f2ec \ + --hash=sha256:11284b32f0036fe7ef4b8b00201dda79c00f3fcea173bc0e5c599e09c937ab0f \ + --hash=sha256:117c8da60d1bd95a6df2692926f36de7971baa1d89ff702fae47b6689a4465ad \ + --hash=sha256:1ff516d129f01bb7a4aa95bc6aae88e4d86dd63bfc2d57db9302c2624d1be7cb \ + --hash=sha256:232445e7b85255ccfe68dfd42185db8a3f3349b34ad7068404856c4a5f67c355 \ + --hash=sha256:2539831838ab5393f78a893d7bbf27d5c36e43baf77e91dc9992922b2b97e09d \ + --hash=sha256:26fee36a6d8f2bd9464f3566b95e3e3fb7fd7dad723f775c500aac8204ec98c6 \ + --hash=sha256:2cda65b689aec63b7c76a77f43a08044fa90bbc6ad9849267cedfee9795913f3 \ + --hash=sha256:2ecdfe9736c4a73343f629586016a137a10faed1a29c6dc699d8ab20c2d3cf64 \ + --hash=sha256:319e4367596fb0d52be645b3de1616faf0fadaf28507ce1c7595bebd9b4c402c \ + --hash=sha256:3678be5db330726f19c1949d8ae1b845a02eeb2a2e1d5a8bb8eaa82087ae25c1 \ + --hash=sha256:370b5b86da9bddbe65fa08711f0e8ffdf8b0036558178d1a31dfcb44efcde72a \ + --hash=sha256:37cc955c84ec67c2d11183d372044399342b20a1fa447b7a33040f4889bba318 \ + --hash=sha256:42b180ed1bec58ab9bdc65d406577e0c0fb7241b74b8c032846073c7743c9f86 \ + --hash=sha256:44f1748035c36c939848c935715bde41734d9249ab7b844ff9bfbe984be8952c \ + --hash=sha256:463ee5f3afbfec29cbf5652752c9d1032bdad63daf48bb8cb9970064cc81d5f9 \ + --hash=sha256:49f5336b82e315a33bef1025d247ca08d95719715b29e33f0e9e8cf15ff1dfb6 \ + --hash=sha256:4c7cfc3d42e81cda802f93aa9e92caf79feaa1711426e28ce620560b8aaf5e4d \ + --hash=sha256:50f03d2330a153a9114c2429061137bd323736059f384de8348d7cb1ca1baa15 \ + --hash=sha256:53614f44f36917282a583180e402105bc63d61d1aca067d51cb7f051eb489901 \ + --hash=sha256:5760a831c0f3c6d3229b50ef3fafa4c164ec99d7e8c2237fe144e67a9d33b120 \ + --hash=sha256:59c7df2103052b30b7c76d4fa8251326c9f82689578a912698a127dc1737f43e \ + --hash=sha256:5bef76c4d9329913cef2fe79ce1f4dab98f77fa4887e5f0420ffc9386941de32 \ + --hash=sha256:5f9afdcf701a1aa3c41e0e748c152d2162434d61639a1e5d8523ecf60ae35aea \ + 
--hash=sha256:60fec380778d75cbb492f14ca974f11f37b41d53c057b9c8ba213315b86e1f84 \ + --hash=sha256:628b654ba555b2ba9111c0936d558b14bfc9d5f57b8c323b02fc846036b38b2f \ + --hash=sha256:638abedb39375f0ddce2de536fc9c976639b2d1b7202d715c2e7a25f0ebfd091 \ + --hash=sha256:67d3adff654dc7f7c7091dd259b3b847fe119c08d0bda61db91e2ea2b61c38c0 \ + --hash=sha256:6859d81243cd09854be9054aca3ecab14a2dee5b3c9f6d7ef12061d478ca0c57 \ + --hash=sha256:68c4699147dded6926a3d2c2f948d435d54d027f69909e0ef3c6587933723ed2 \ + --hash=sha256:6cba7483ba45600346a35c466bde32327b108575022f73c35a0f7170b5a71ae2 \ + --hash=sha256:72d9967fb1f927542cfb5347207fde01b29f25c9bb8cbc7ced280decfa015983 \ + --hash=sha256:72e95184bf5b9a4c08153ed07c16c130ff174835c9a1e6ee2b311be758c8b3ef \ + --hash=sha256:7560fca3e17a6bc876d20cd825d7721c101fa2b1cd0bfa0abf9a2e781e49b37b \ + --hash=sha256:7618b84118ae704f7fa23c4a190bd80fc605671841a4427d5ca14b9b8d9ec1a3 \ + --hash=sha256:7975178f9478ccedcf613332d5d6f37b67c74ef4e2e47e0c965597506b921f04 \ + --hash=sha256:7d9025b185465d9d18679406f6f394850347d5ed2681efc203539d800f36f459 \ + --hash=sha256:89cbeec7e9d5d8773ec4779c64e3cbcbff53d234ca6ad7b1a3736588003bba48 \ + --hash=sha256:8b019c4810903fdea3b230f358b9d27377c0f38454778b607676c9e1b57d14b7 \ + --hash=sha256:8db3a6f3d430ac3dc3793c53fa8e5e665c23ba359484d365a191027ad8b65a30 \ + --hash=sha256:8e63781da85aa8948864970e529af10abc4084a990d30850c41bbdb5f83eee45 \ + --hash=sha256:901635098565773a44f74068639d265f19deaaca47ea77b428fd9bee13a61d87 \ + --hash=sha256:91d32bd1056c0e83a0f90e4ffa213c25096b2d8b9f0e2d172a45f138c7d8c081 \ + --hash=sha256:92c34de04fec7f4ff95f7667d4eb085c4e4db46c31ef44c3d35c38df128430da \ + --hash=sha256:930c19b699dd7e1077eac98967adc2fe5f0b104bd96cc1f26778ab82b31ceb24 \ + --hash=sha256:956729b7dd599020e57133fb95b777e4f81ee069ff0a70e80f6eeac82658972f \ + --hash=sha256:9930f31f603ecc6ea54d5c6dfa299f926ab3e921f72f94babcb02598c32b57c6 \ + --hash=sha256:a1e30a13376db5329570e09b14c8eb36c017909ed7e88591ca3aa81f3c7d6f32 \ + 
--hash=sha256:a480bd902e327dfcaa52b7dd14fdc71e7aa45d73a3d6e41e028a75891d2823cf \ + --hash=sha256:a687099e085f5162e5b88b3402adb6c2b41046180c015c5075c9504440b6e971 \ + --hash=sha256:a7093767e070269e22e2c5f845e46510304f124c32d2cd249633c0f27eb29d86 \ + --hash=sha256:aae42798ba1da3bc1572b2048fe42e61dd6bacced2b424cb0f5572c5432f79c2 \ + --hash=sha256:acfc8db61c6e919d932448cc7985b85e330c8d745528e12fce6e62d40d268bce \ + --hash=sha256:ad759ba39cd32c2c2247864d02c84ea5883b5f6cc6a4ee0c95602a3dde52268f \ + --hash=sha256:b05ec04132394c20bd6bcb692d557a8eb8ab1bac1646d28e49c67c00907d17c8 \ + --hash=sha256:b886e0f5c72aa4249c609c24b9610a9ca83fd963cbb5066b19302723ea505279 \ + --hash=sha256:ba336bc9107acbc1da2ad30967df7b2db93448ca66538ad86aa1fbb91116f631 \ + --hash=sha256:bfe164a1c72c6be3c5c26753c6c412f81412f4dae0d7d06371e0b396a9cc0fc9 \ + --hash=sha256:c11444984aecd342f0cf160c3320288edeb1763871fbb560ed466654b2a7016c \ + --hash=sha256:c2c659f2106b6d154f118ad1b700e68148c46c59b720f04867b1fc5f26a85060 \ + --hash=sha256:c318a5acb429ca38f632577754235140bbb8c5a27faca1c51b43fbf575596e34 \ + --hash=sha256:c375161b588982be381c43eb7158c250f430793d0f708ce379a0f196164c6778 \ + --hash=sha256:c65d76052561c60e17cb4fa289885ed00a9995d59e97019fac2138bd45142057 \ + --hash=sha256:c84b456ff8525ec3ff09762e32ccc27888d036dcd0ba2883e1db491e164dd725 \ + --hash=sha256:c84d3cb1349936c2b96ca6175b50f5a9518170bffd76464219ee0ea6022a64a7 \ + --hash=sha256:ca0bfc79b27d84fcb7fa09339b2ee39077896738d9a30ff99c0332376e985072 \ + --hash=sha256:ca304402ea66d58f99c05aa3d7a6052faea61e5a8313b94f6bc36fbf27960e2d \ + --hash=sha256:caf0df8657277e32671aa8a4d3cc05f2050ab19d9b49447f2265304168e9032c \ + --hash=sha256:cb3c6bc6e599e46a26ad559ad5dec260ffdf705663cc9b894033d64a69314e86 \ + --hash=sha256:cce4d1a97a7eb2253b5d3f29f4a478d8c37ba0303ea34024eb9e65506d4209f8 \ + --hash=sha256:d091c62cb7abbd32e527a85c41f7c8eb4526a926251891fc4ecbe5f974142ffb \ + --hash=sha256:d3a6330c9f1deda22873e8b4ac849cc06d3ff33d60b3217ac0bb397b541e1509 \ + 
--hash=sha256:d49567a2754e9991c05c2b5a7e6650b56e24365b7cab504558e58033dcf0edc4 \ + --hash=sha256:d72d25c57a9c814240802d188ff0a808b701e2dd2bf1c64721c7088ceeeb1ed7 \ + --hash=sha256:db96cf092d86d4cb543daa9148e299011e0a40770380bb78333b9fd700586fcb \ + --hash=sha256:df4f058e96e8b467b7742e5dba7564255cd482d3c1e6cf81f8cb683bb0433340 \ + --hash=sha256:e3b6082e9532309727273443c8943bb9558d52e36788b246aa278bda7c642116 \ + --hash=sha256:e448b2be0430ab839cf7954715c39d6f34ff6cf2b49393f336283b7a59f485af \ + --hash=sha256:e8984114fd83ed3913d89526c992395920930c9620a2feee61faf035f41d7b9a \ + --hash=sha256:e9f27399b8d50c5d3f08f0aae961bcc66a1dead1cd0ae9401e4c2a43a623322a \ + --hash=sha256:ea3b3f8908a9a5b9d6fc632b5f012ece7240031c44c6d4764809f33736534166 \ + --hash=sha256:ebefbc26ccff5e96ae7d40772172e7310174f9aa3683d2870a1882313ec3a4d5 \ + --hash=sha256:ec8f46d533092d8e20bc742c47918cbe24b8641dbfbbcb83177c5de3c9d4decb \ + --hash=sha256:ee6b63aecf929a7bcf885bdc8a8aec96c43bc4442f63fe8c6d48f24fc992b05b \ + --hash=sha256:f475d5eda41d2ed51ca775a07c80529a923dd759fcff7abf03ccdd83d9f7564e \ + --hash=sha256:f522f28c88a0d5b2f9e895cf405dd594cd518e99d61905406aec74d30eb6383b \ + --hash=sha256:f77371b5030e53f8bf92197640af437539e3bba1bc8342b97888c8e26567bfdc \ + --hash=sha256:f8cf2fcdc2368df4317e05571e33810eeed24cd594acc9dfc9788b21dac6b3a8 \ + --hash=sha256:fe2ea1177146a7ab345ab61e90a490eeea25d5f063e1cb9d4eb1425b169b64d7 \ + --hash=sha256:fee553657dcdb7e73df8823c49e8611457ba46e9d7026b7e9c44820c08c327c3 \ + --hash=sha256:ff66577ae55114f7d0f6aa0d4d335f27cae96bf245962a745b718ec887bbe7eb + # via transformers +toml==0.10.2 \ + --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \ + --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f + # via + # tensorflow-datasets + # tfds-nightly +toolz==0.12.0 \ + --hash=sha256:2059bd4148deb1884bb0eb770a3cde70e7f954cfbbdc2285f1f2de01fd21eb6f \ + 
--hash=sha256:88c570861c440ee3f2f6037c4654613228ff40c93a6c25e0eba70d17282c6194 + # via chex +tqdm==4.66.1 \ + --hash=sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386 \ + --hash=sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7 + # via + # etils + # gdown + # huggingface-hub + # nltk + # tensorflow-datasets + # tfds-nightly + # transformers +transformers==4.34.0 \ + --hash=sha256:3f0187183a7f22c51ecbbc9eac5145df666c5b86bec6feed10e11f0363f3a1f9 \ + --hash=sha256:cc2ae61bfbfaa45337fd9017326669fc60e4f55125f589d50da47819e3d6f504 + # via t5 +typing-extensions==4.5.0 \ + --hash=sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb \ + --hash=sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4 + # via + # chex + # etils + # flax + # huggingface-hub + # tensorflow +tzdata==2023.3 \ + --hash=sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a \ + --hash=sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda + # via pandas +urllib3==2.0.6 \ + --hash=sha256:7a7c7003b000adf9e7ca2a377c9688bbc54ed41b985789ed576570342a375cd2 \ + --hash=sha256:b19e1a85d206b56d7df1d5e683df4a7725252a964e3993648dd0fb5a1c157564 + # via requests +werkzeug==3.0.0 \ + --hash=sha256:3ffff4dcc32db52ef3cc94dff3000a3c2846890f3a5a51800a27b909c5e770f0 \ + --hash=sha256:cbb2600f7eabe51dbc0502f58be0b3e1b96b893b05695ea2b35b43d4de2d9962 + # via tensorboard +wheel==0.41.2 \ + --hash=sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985 \ + --hash=sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8 + # via + # astunparse + # tensorboard +wrapt==1.15.0 \ + --hash=sha256:02fce1852f755f44f95af51f69d22e45080102e9d00258053b79367d07af39c0 \ + --hash=sha256:077ff0d1f9d9e4ce6476c1a924a3332452c1406e59d90a2cf24aeb29eeac9420 \ + --hash=sha256:078e2a1a86544e644a68422f881c48b84fef6d18f8c7a957ffd3f2e0a74a0d4a \ + 
--hash=sha256:0970ddb69bba00670e58955f8019bec4a42d1785db3faa043c33d81de2bf843c \ + --hash=sha256:1286eb30261894e4c70d124d44b7fd07825340869945c79d05bda53a40caa079 \ + --hash=sha256:21f6d9a0d5b3a207cdf7acf8e58d7d13d463e639f0c7e01d82cdb671e6cb7923 \ + --hash=sha256:230ae493696a371f1dbffaad3dafbb742a4d27a0afd2b1aecebe52b740167e7f \ + --hash=sha256:26458da5653aa5b3d8dc8b24192f574a58984c749401f98fff994d41d3f08da1 \ + --hash=sha256:2cf56d0e237280baed46f0b5316661da892565ff58309d4d2ed7dba763d984b8 \ + --hash=sha256:2e51de54d4fb8fb50d6ee8327f9828306a959ae394d3e01a1ba8b2f937747d86 \ + --hash=sha256:2fbfbca668dd15b744418265a9607baa970c347eefd0db6a518aaf0cfbd153c0 \ + --hash=sha256:38adf7198f8f154502883242f9fe7333ab05a5b02de7d83aa2d88ea621f13364 \ + --hash=sha256:3a8564f283394634a7a7054b7983e47dbf39c07712d7b177b37e03f2467a024e \ + --hash=sha256:3abbe948c3cbde2689370a262a8d04e32ec2dd4f27103669a45c6929bcdbfe7c \ + --hash=sha256:3bbe623731d03b186b3d6b0d6f51865bf598587c38d6f7b0be2e27414f7f214e \ + --hash=sha256:40737a081d7497efea35ab9304b829b857f21558acfc7b3272f908d33b0d9d4c \ + --hash=sha256:41d07d029dd4157ae27beab04d22b8e261eddfc6ecd64ff7000b10dc8b3a5727 \ + --hash=sha256:46ed616d5fb42f98630ed70c3529541408166c22cdfd4540b88d5f21006b0eff \ + --hash=sha256:493d389a2b63c88ad56cdc35d0fa5752daac56ca755805b1b0c530f785767d5e \ + --hash=sha256:4ff0d20f2e670800d3ed2b220d40984162089a6e2c9646fdb09b85e6f9a8fc29 \ + --hash=sha256:54accd4b8bc202966bafafd16e69da9d5640ff92389d33d28555c5fd4f25ccb7 \ + --hash=sha256:56374914b132c702aa9aa9959c550004b8847148f95e1b824772d453ac204a72 \ + --hash=sha256:578383d740457fa790fdf85e6d346fda1416a40549fe8db08e5e9bd281c6a475 \ + --hash=sha256:58d7a75d731e8c63614222bcb21dd992b4ab01a399f1f09dd82af17bbfc2368a \ + --hash=sha256:5c5aa28df055697d7c37d2099a7bc09f559d5053c3349b1ad0c39000e611d317 \ + --hash=sha256:5fc8e02f5984a55d2c653f5fea93531e9836abbd84342c1d1e17abc4a15084c2 \ + --hash=sha256:63424c681923b9f3bfbc5e3205aafe790904053d42ddcc08542181a30a7a51bd \ + 
--hash=sha256:64b1df0f83706b4ef4cfb4fb0e4c2669100fd7ecacfb59e091fad300d4e04640 \ + --hash=sha256:74934ebd71950e3db69960a7da29204f89624dde411afbfb3b4858c1409b1e98 \ + --hash=sha256:75669d77bb2c071333417617a235324a1618dba66f82a750362eccbe5b61d248 \ + --hash=sha256:75760a47c06b5974aa5e01949bf7e66d2af4d08cb8c1d6516af5e39595397f5e \ + --hash=sha256:76407ab327158c510f44ded207e2f76b657303e17cb7a572ffe2f5a8a48aa04d \ + --hash=sha256:76e9c727a874b4856d11a32fb0b389afc61ce8aaf281ada613713ddeadd1cfec \ + --hash=sha256:77d4c1b881076c3ba173484dfa53d3582c1c8ff1f914c6461ab70c8428b796c1 \ + --hash=sha256:780c82a41dc493b62fc5884fb1d3a3b81106642c5c5c78d6a0d4cbe96d62ba7e \ + --hash=sha256:7dc0713bf81287a00516ef43137273b23ee414fe41a3c14be10dd95ed98a2df9 \ + --hash=sha256:7eebcdbe3677e58dd4c0e03b4f2cfa346ed4049687d839adad68cc38bb559c92 \ + --hash=sha256:896689fddba4f23ef7c718279e42f8834041a21342d95e56922e1c10c0cc7afb \ + --hash=sha256:96177eb5645b1c6985f5c11d03fc2dbda9ad24ec0f3a46dcce91445747e15094 \ + --hash=sha256:96e25c8603a155559231c19c0349245eeb4ac0096fe3c1d0be5c47e075bd4f46 \ + --hash=sha256:9d37ac69edc5614b90516807de32d08cb8e7b12260a285ee330955604ed9dd29 \ + --hash=sha256:9ed6aa0726b9b60911f4aed8ec5b8dd7bf3491476015819f56473ffaef8959bd \ + --hash=sha256:a487f72a25904e2b4bbc0817ce7a8de94363bd7e79890510174da9d901c38705 \ + --hash=sha256:a4cbb9ff5795cd66f0066bdf5947f170f5d63a9274f99bdbca02fd973adcf2a8 \ + --hash=sha256:a74d56552ddbde46c246b5b89199cb3fd182f9c346c784e1a93e4dc3f5ec9975 \ + --hash=sha256:a89ce3fd220ff144bd9d54da333ec0de0399b52c9ac3d2ce34b569cf1a5748fb \ + --hash=sha256:abd52a09d03adf9c763d706df707c343293d5d106aea53483e0ec8d9e310ad5e \ + --hash=sha256:abd8f36c99512755b8456047b7be10372fca271bf1467a1caa88db991e7c421b \ + --hash=sha256:af5bd9ccb188f6a5fdda9f1f09d9f4c86cc8a539bd48a0bfdc97723970348418 \ + --hash=sha256:b02f21c1e2074943312d03d243ac4388319f2456576b2c6023041c4d57cd7019 \ + --hash=sha256:b06fa97478a5f478fb05e1980980a7cdf2712015493b44d0c87606c1513ed5b1 \ + 
--hash=sha256:b0724f05c396b0a4c36a3226c31648385deb6a65d8992644c12a4963c70326ba \ + --hash=sha256:b130fe77361d6771ecf5a219d8e0817d61b236b7d8b37cc045172e574ed219e6 \ + --hash=sha256:b56d5519e470d3f2fe4aa7585f0632b060d532d0696c5bdfb5e8319e1d0f69a2 \ + --hash=sha256:b67b819628e3b748fd3c2192c15fb951f549d0f47c0449af0764d7647302fda3 \ + --hash=sha256:ba1711cda2d30634a7e452fc79eabcadaffedf241ff206db2ee93dd2c89a60e7 \ + --hash=sha256:bbeccb1aa40ab88cd29e6c7d8585582c99548f55f9b2581dfc5ba68c59a85752 \ + --hash=sha256:bd84395aab8e4d36263cd1b9308cd504f6cf713b7d6d3ce25ea55670baec5416 \ + --hash=sha256:c99f4309f5145b93eca6e35ac1a988f0dc0a7ccf9ccdcd78d3c0adf57224e62f \ + --hash=sha256:ca1cccf838cd28d5a0883b342474c630ac48cac5df0ee6eacc9c7290f76b11c1 \ + --hash=sha256:cd525e0e52a5ff16653a3fc9e3dd827981917d34996600bbc34c05d048ca35cc \ + --hash=sha256:cdb4f085756c96a3af04e6eca7f08b1345e94b53af8921b25c72f096e704e145 \ + --hash=sha256:ce42618f67741d4697684e501ef02f29e758a123aa2d669e2d964ff734ee00ee \ + --hash=sha256:d06730c6aed78cee4126234cf2d071e01b44b915e725a6cb439a879ec9754a3a \ + --hash=sha256:d5fe3e099cf07d0fb5a1e23d399e5d4d1ca3e6dfcbe5c8570ccff3e9208274f7 \ + --hash=sha256:d6bcbfc99f55655c3d93feb7ef3800bd5bbe963a755687cbf1f490a71fb7794b \ + --hash=sha256:d787272ed958a05b2c86311d3a4135d3c2aeea4fc655705f074130aa57d71653 \ + --hash=sha256:e169e957c33576f47e21864cf3fc9ff47c223a4ebca8960079b8bd36cb014fd0 \ + --hash=sha256:e20076a211cd6f9b44a6be58f7eeafa7ab5720eb796975d0c03f05b47d89eb90 \ + --hash=sha256:e826aadda3cae59295b95343db8f3d965fb31059da7de01ee8d1c40a60398b29 \ + --hash=sha256:eef4d64c650f33347c1f9266fa5ae001440b232ad9b98f1f43dfe7a79435c0a6 \ + --hash=sha256:f2e69b3ed24544b0d3dbe2c5c0ba5153ce50dcebb576fdc4696d52aa22db6034 \ + --hash=sha256:f87ec75864c37c4c6cb908d282e1969e79763e0d9becdfe9fe5473b7bb1e5f09 \ + --hash=sha256:fbec11614dba0424ca72f4e8ba3c420dba07b4a7c206c8c8e4e73f2e98f4c559 \ + --hash=sha256:fd69666217b62fa5d7c6aa88e507493a34dec4fa20c5bd925e4bc12fce586639 + # via + # 
tensorflow + # tensorflow-datasets + # tfds-nightly +zipp==3.17.0 \ + --hash=sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31 \ + --hash=sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0 + # via etils + +# WARNING: The following packages were not pinned, but pip requires them to be +# pinned when the requirements file includes hashes. Consider using the --allow-unsafe flag. +# setuptools diff --git a/rules.txt b/rules.txt new file mode 100644 index 0000000..f329f76 --- /dev/null +++ b/rules.txt @@ -0,0 +1,43 @@ +perp A B C D, perp C D E F, ncoll A B E => para A B E F +cong O A O B, cong O B O C, cong O C O D => cyclic A B C D +eqangle A B P Q C D P Q => para A B C D +cyclic A B P Q => eqangle P A P B Q A Q B +eqangle6 P A P B Q A Q B, ncoll P Q A B => cyclic A B P Q +cyclic A B C P Q R, eqangle C A C B R P R Q => cong A B P Q +midp E A B, midp F A C => para E F B C +para A B C D, coll O A C, coll O B D => eqratio3 A B C D O O +perp A B C D, perp E F G H, npara A B E F => eqangle A B E F C D G H +eqangle a b c d m n p q, eqangle c d e f p q r u => eqangle a b e f m n r u +eqratio a b c d m n p q, eqratio c d e f p q r u => eqratio a b e f m n r u +eqratio6 d b d c a b a c, coll d b c, ncoll a b c => eqangle6 a b a d a d a c +eqangle6 a b a d a d a c, coll d b c, ncoll a b c => eqratio6 d b d c a b a c +cong O A O B, ncoll O A B => eqangle O A A B A B O B +eqangle6 A O A B B A B O, ncoll O A B => cong O A O B +circle O A B C, perp O A A X => eqangle A X A B C A C B +circle O A B C, eqangle A X A B C A C B => perp O A A X +circle O A B C, midp M B C => eqangle A B A C O B O M +circle O A B C, coll M B C, eqangle A B A C O B O M => midp M B C +perp A B B C, midp M A C => cong A M B M +circle O A B C, coll O A C => perp A B B C +cyclic A B C D, para A B C D => eqangle A D C D C D C B +midp M A B, perp O M A B => cong O A O B +cong A P B P, cong A Q B Q => perp A B P Q +cong A P B P, cong A Q B Q, cyclic A B P Q => perp P A A 
Q +midp M A B, midp M C D => para A C B D +midp M A B, para A C B D, para A D B C => midp M C D +eqratio O A A C O B B D, coll O A C, coll O B D, ncoll A B C, sameside A O C B O D => para A B C D +para A B A C => coll A B C +midp M A B, midp N C D => eqratio M A A B N C C D +eqangle A B P Q C D U V, perp P Q U V => perp A B C D +eqratio A B P Q C D U V, cong P Q U V => cong A B C D +cong A B P Q, cong B C Q R, cong C A R P, ncoll A B C => contri* A B C P Q R +cong A B P Q, cong B C Q R, eqangle6 B A B C Q P Q R, ncoll A B C => contri* A B C P Q R +eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C => simtri A B C P Q R +eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C => simtri2 A B C P Q R +eqangle6 B A B C Q P Q R, eqangle6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri A B C P Q R +eqangle6 B A B C Q R Q P, eqangle6 C A C B R Q R P, ncoll A B C, cong A B P Q => contri2 A B C P Q R +eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C => simtri* A B C P Q R +eqratio6 B A B C Q P Q R, eqangle6 B A B C Q P Q R, ncoll A B C => simtri* A B C P Q R +eqratio6 B A B C Q P Q R, eqratio6 C A C B R P R Q, ncoll A B C, cong A B P Q => contri* A B C P Q R +para a b c d, coll m a d, coll n b c, eqratio6 m a m d n b n c, sameside m a d n b c => para m n a b +para a b c d, coll m a d, coll n b c, para m n a b => eqratio6 m a m d n b n c diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..b162f79 --- /dev/null +++ b/run.sh @@ -0,0 +1,72 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+#!/bin/bash
+set -e
+set -x
+
+virtualenv -p python3 .
+source ./bin/activate
+
+pip install --require-hashes -r requirements.txt
+
+gdown --folder https://bit.ly/alphageometry
+DATA=ag_ckpt_vocab
+
+MELIAD_PATH=meliad_lib/meliad
+mkdir -p $MELIAD_PATH
+git clone https://github.com/google-research/meliad $MELIAD_PATH
+export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
+
+DDAR_ARGS=(
+  --defs_file=$(pwd)/defs.txt \
+  --rules_file=$(pwd)/rules.txt \
+);
+
+BATCH_SIZE=2
+BEAM_SIZE=2
+DEPTH=2
+
+SEARCH_ARGS=(
+  --beam_size=$BEAM_SIZE
+  --search_depth=$DEPTH
+)
+
+LM_ARGS=(
+  --ckpt_path=$DATA \
+  --vocab_path=$DATA/geometry.757.model \
+  --gin_search_paths=$MELIAD_PATH/transformer/configs \
+  --gin_file=base_htrans.gin \
+  --gin_file=size/medium_150M.gin \
+  --gin_file=options/positions_t5.gin \
+  --gin_file=options/lr_cosine_decay.gin \
+  --gin_file=options/seq_1024_nocache.gin \
+  --gin_file=geometry_150M_generate.gin \
+  --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
+  --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
+  --gin_param=TransformerTaskConfig.sequence_length=128 \
+  --gin_param=Trainer.restore_state_variables=False
+);
+
+echo $PYTHONPATH
+
+python -m alphageometry \
+--alsologtostderr \
+--problems_file=$(pwd)/examples.txt \
+--problem_name=orthocenter \
+--mode=alphageometry \
+"${DDAR_ARGS[@]}" \
+"${SEARCH_ARGS[@]}" \
+"${LM_ARGS[@]}"
diff --git a/run_tests.sh b/run_tests.sh
new file mode 100644
index 0000000..9926bed
--- /dev/null
+++ 
b/run_tests.sh @@ -0,0 +1,30 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +DATA=ag_ckpt_vocab +MELIAD_PATH=meliad_lib/meliad +PYTHONPATH=$PYTHONPATH:$MELIAD_PATH + +python problem_test.py +python geometry_test.py +python graph_utils_test.py +python numericals_test.py +python graph_test.py +python dd_test.py +python ar_test.py +python ddar_test.py +python trace_back_test.py +python alphageometry_test.py +python lm_inference_test.py --meliad_path=$MELIAD_PATH --data_path=$DATA diff --git a/trace_back.py b/trace_back.py new file mode 100644 index 0000000..2543918 --- /dev/null +++ b/trace_back.py @@ -0,0 +1,374 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Implements DAG-level traceback."""
+
+from typing import Any
+
+import geometry as gm
+import pretty as pt
+import problem
+
+
+pretty = pt.pretty
+
+
+def point_levels(
+    setup: list[problem.Dependency], existing_points: list[gm.Point]
+) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
+  """Reformat setup into levels of point constructions."""
+  levels = []
+  for con in setup:
+    plevel = max([p.plevel for p in con.args if isinstance(p, gm.Point)])
+
+    while len(levels) - 1 < plevel:
+      levels.append((set(), []))
+
+    for p in con.args:
+      if not isinstance(p, gm.Point):
+        continue
+      if existing_points and p in existing_points:
+        continue
+
+      levels[p.plevel][0].add(p)
+
+    cons = levels[plevel][1]
+    cons.append(con)
+
+  return [(p, c) for p, c in levels if p or c]
+
+
+def point_log(
+    setup: list[problem.Dependency],
+    ref_id: dict[tuple[str, ...], int],
+    existing_points: list[gm.Point],
+) -> list[tuple[list[gm.Point], list[problem.Dependency]]]:
+  """Reformat setup into groups of point constructions."""
+  log = []
+
+  levels = point_levels(setup, existing_points)
+
+  for points, cons in levels:
+    for con in cons:
+      if con.hashed() not in ref_id:
+        ref_id[con.hashed()] = len(ref_id)
+
+    log.append((points, cons))
+
+  return log
+
+
+def setup_to_levels(
+    setup: list[problem.Dependency],
+) -> list[list[problem.Dependency]]:
+  """Reformat setup into levels of point constructions."""
+  levels = []
+  for d in setup:
+    plevel = max([p.plevel for p in d.args if isinstance(p, gm.Point)])
+    while len(levels) - 1 < plevel:
+      levels.append([])
+
+    levels[plevel].append(d)
+
+  levels = [lvl for lvl in levels if lvl]
+  return levels
+
+
+def separate_dependency_difference(
+    query: problem.Dependency,
+    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
+) -> tuple[
+    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
+    list[problem.Dependency],
list[problem.Dependency], + set[gm.Point], + set[gm.Point], +]: + """Identify and separate the dependency difference.""" + setup = [] + log_, log = log, [] + for prems, cons in log_: + if not prems: + setup.extend(cons) + continue + cons_ = [] + for con in cons: + if con.rule_name == 'c0': + setup.append(con) + else: + cons_.append(con) + if not cons_: + continue + + prems = [p for p in prems if p.name != 'ind'] + log.append((prems, cons_)) + + points = set(query.args) + queue = list(query.args) + i = 0 + while i < len(queue): + q = queue[i] + i += 1 + if not isinstance(q, gm.Point): + continue + for p in q.rely_on: + if p not in points: + points.add(p) + queue.append(p) + + setup_, setup, aux_setup, aux_points = setup, [], [], set() + for con in setup_: + if con.name == 'ind': + continue + elif any([p not in points for p in con.args if isinstance(p, gm.Point)]): + aux_setup.append(con) + aux_points.update( + [p for p in con.args if isinstance(p, gm.Point) and p not in points] + ) + else: + setup.append(con) + + return log, setup, aux_setup, points, aux_points + + +def recursive_traceback( + query: problem.Dependency, +) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]: + """Recursively traceback from the query, i.e. 
the conclusion.""" + visited = set() + log = [] + stack = [] + + def read(q: problem.Dependency) -> None: + q = q.remove_loop() + hashed = q.hashed() + if hashed in visited: + return + + if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']: + return + + nonlocal stack + + stack.append(hashed) + prems = [] + + if q.rule_name != problem.CONSTRUCTION_RULE: + all_deps = [] + dep_names = set() + for d in q.why: + if d.hashed() in dep_names: + continue + dep_names.add(d.hashed()) + all_deps.append(d) + + for d in all_deps: + h = d.hashed() + if h not in visited: + read(d) + if h in visited: + prems.append(d) + + visited.add(hashed) + hashs = sorted([d.hashed() for d in prems]) + found = False + for ps, qs in log: + if sorted([d.hashed() for d in ps]) == hashs: + qs += [q] + found = True + break + if not found: + log.append((prems, [q])) + + stack.pop(-1) + + read(query) + + # post process log: separate multi-conclusion lines + log_, log = log, [] + for ps, qs in log_: + for q in qs: + log.append((ps, [q])) + + return log + + +def collx_to_coll_setup( + setup: list[problem.Dependency], +) -> list[problem.Dependency]: + """Convert collx to coll in setups.""" + result = [] + for level in setup_to_levels(setup): + hashs = set() + for dep in level: + if dep.name == 'collx': + dep.name = 'coll' + dep.args = list(set(dep.args)) + + if dep.hashed() in hashs: + continue + hashs.add(dep.hashed()) + result.append(dep) + + return result + + +def collx_to_coll( + setup: list[problem.Dependency], + aux_setup: list[problem.Dependency], + log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], +) -> tuple[ + list[problem.Dependency], + list[problem.Dependency], + list[tuple[list[problem.Dependency], list[problem.Dependency]]], +]: + """Convert collx to coll and dedup.""" + setup = collx_to_coll_setup(setup) + aux_setup = collx_to_coll_setup(aux_setup) + + con_set = set([p.hashed() for p in setup + aux_setup]) + log_, log = log, [] + for prems, cons in log_: + 
prem_set = set() + prems_, prems = prems, [] + for p in prems_: + if p.name == 'collx': + p.name = 'coll' + p.args = list(set(p.args)) + if p.hashed() in prem_set: + continue + prem_set.add(p.hashed()) + prems.append(p) + + cons_, cons = cons, [] + for c in cons_: + if c.name == 'collx': + c.name = 'coll' + c.args = list(set(c.args)) + if c.hashed() in con_set: + continue + con_set.add(c.hashed()) + cons.append(c) + + if not cons or not prems: + continue + + log.append((prems, cons)) + + return setup, aux_setup, log + + +def get_logs( + query: problem.Dependency, g: Any, merge_trivials: bool = False +) -> tuple[ + list[problem.Dependency], + list[problem.Dependency], + list[tuple[list[problem.Dependency], list[problem.Dependency]]], + set[gm.Point], +]: + """Given a DAG and conclusion N, return the premise, aux, proof.""" + query = query.why_me_or_cache(g, query.level) + log = recursive_traceback(query) + log, setup, aux_setup, setup_points, _ = separate_dependency_difference( + query, log + ) + + setup, aux_setup, log = collx_to_coll(setup, aux_setup, log) + + setup, aux_setup, log = shorten_and_shave( + setup, aux_setup, log, merge_trivials + ) + + return setup, aux_setup, log, setup_points + + +def shorten_and_shave( + setup: list[problem.Dependency], + aux_setup: list[problem.Dependency], + log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], + merge_trivials: bool = False, +) -> tuple[ + list[problem.Dependency], + list[problem.Dependency], + list[tuple[list[problem.Dependency], list[problem.Dependency]]], +]: + """Shorten the proof by removing unused predicates.""" + log, _ = shorten_proof(log, merge_trivials=merge_trivials) + + all_prems = sum([list(prems) for prems, _ in log], []) + all_prems = set([p.hashed() for p in all_prems]) + setup = [d for d in setup if d.hashed() in all_prems] + aux_setup = [d for d in aux_setup if d.hashed() in all_prems] + return setup, aux_setup, log + + +def join_prems( + con: problem.Dependency, + con2prems: 
dict[tuple[str, ...], list[problem.Dependency]], + expanded: set[tuple[str, ...]], +) -> list[problem.Dependency]: + """Join proof steps with the same premises.""" + h = con.hashed() + if h in expanded or h not in con2prems: + return [con] + + result = [] + for p in con2prems[h]: + result += join_prems(p, con2prems, expanded) + return result + + +def shorten_proof( + log: list[tuple[list[problem.Dependency], list[problem.Dependency]]], + merge_trivials: bool = False, +) -> tuple[ + list[tuple[list[problem.Dependency], list[problem.Dependency]]], + dict[tuple[str, ...], list[problem.Dependency]], +]: + """Join multiple trivials proof steps into one.""" + pops = set() + con2prem = {} + for prems, cons in log: + assert len(cons) == 1 + con = cons[0] + if con.rule_name == '': # pylint: disable=g-explicit-bool-comparison + con2prem[con.hashed()] = prems + elif not merge_trivials: + # except for the ones that are premises to non-trivial steps. + pops.update({p.hashed() for p in prems}) + + for p in pops: + if p in con2prem: + con2prem.pop(p) + + expanded = set() + log2 = [] + for i, (prems, cons) in enumerate(log): + con = cons[0] + if i < len(log) - 1 and con.hashed() in con2prem: + continue + + hashs = set() + new_prems = [] + + for p in sum([join_prems(p, con2prem, expanded) for p in prems], []): + if p.hashed() not in hashs: + new_prems.append(p) + hashs.add(p.hashed()) + + log2 += [(new_prems, [con])] + expanded.add(con.hashed()) + + return log2, con2prem diff --git a/trace_back_test.py b/trace_back_test.py new file mode 100644 index 0000000..fb402ef --- /dev/null +++ b/trace_back_test.py @@ -0,0 +1,61 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Unit testing for the trace_back code.""" + +import unittest + +from absl.testing import absltest +import ddar +import graph as gh +import problem as pr +import trace_back as tb + + +class TracebackTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True) + cls.rules = pr.Theorem.from_txt_file('rules.txt', to_dict=True) + + def test_orthocenter_dependency_difference(self): + txt = 'a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? 
perp a d b c' # pylint: disable=line-too-long + p = pr.Problem.from_txt(txt) + g, _ = gh.Graph.build_problem(p, TracebackTest.defs) + + ddar.solve(g, TracebackTest.rules, p) + + goal_args = g.names2nodes(p.goal.args) + query = pr.Dependency(p.goal.name, goal_args, None, None) + + setup, aux, _, _ = tb.get_logs(query, g, merge_trivials=False) + + # Convert each predicates to its hash string: + setup = [p.hashed() for p in setup] + aux = [p.hashed() for p in aux] + + self.assertCountEqual( + setup, [('perp', 'a', 'c', 'b', 'd'), ('perp', 'a', 'b', 'c', 'd')] + ) + + self.assertCountEqual( + aux, [('coll', 'a', 'c', 'e'), ('coll', 'b', 'd', 'e')] + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/transformer_layer.py b/transformer_layer.py new file mode 100644 index 0000000..9f97c5b --- /dev/null +++ b/transformer_layer.py @@ -0,0 +1,527 @@ +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A single transformer layer in inference mode. + +Modified +https://github.com/google-research/meliad/blob/main/transformer/transformer_layer.py +To accommodate sequence packing + kv cache + relative position during test time. 
+""" + +from typing import Callable, Mapping, NewType, Optional, Tuple + +from absl import logging +import gin +import jax +import jax.numpy as jnp +from transformer import attention +from transformer import nn_components +from transformer import position +from transformer import transformer_layer + + +Array = jnp.ndarray +DecoderState = NewType("DecoderState", Mapping[str, Array]) +WindowState = Optional[Tuple[attention.KVITuple, Array]] + + +@jax.vmap +def update_slice_in_dim_1(array: Array, update: Array, idx: Array) -> Array: + """Update a stored keys/values slice for different-lengthed seqs in batch.""" + return jax.lax.dynamic_update_slice_in_dim(array, update, idx, axis=0) + + +def slice_in_dim_1(window_length: int) -> Callable[[Array, Array], Array]: + @jax.vmap + def fn(array: Array, idx: Array) -> Array: + return jax.lax.dynamic_slice_in_dim(array, idx, window_length, axis=0) + + return fn + + +@gin.configurable +class TransformerLayerGenerate(transformer_layer.TransformerLayer): + """Full transformer layer, with attention.""" + + def _next_decoder_state( + self, decoder_state: DecoderState, keys: Array, values: Array + ) -> Tuple[DecoderState, Array, Array]: + """Compute the next decoder state, and return keys,values to attend to. + + The keys,values returned from this function are drawn from the prior + decoding state, and comprise a full window of local context. + + Args: + decoder_state: The current decoder state, initially created using + init_decoder_state(). + keys: The key for the current token, of shape (batch_size, 1, dim) + values: The value for the current token of shape (batch_size, 1, dim) + + Returns: + (next_decoder_state, + window of keys of shape (batch_size, window_length, dim), + window of values of shape (batch_size, window_length, dim)) + """ + + assert keys.shape[1] == 1 # single-token autoregressive decoding. 
+ + # Unpack decoder_state + stored_keys = decoder_state["keys"] + stored_values = decoder_state["values"] + curr_index = decoder_state["current_index"] + + # Slice to get window_length-sized chunk of previous keys,values. + out_decoder_state = {} + curr_win_index = curr_index - self.window_length + + # out_keys = jax.lax.dynamic_slice_in_dim( + # stored_keys, curr_win_index, self.window_length, axis=1) + out_keys = slice_in_dim_1(self.window_length)(stored_keys, curr_win_index) + + # out_values = jax.lax.dynamic_slice_in_dim( + # stored_values, curr_win_index, self.window_length, axis=1) + out_values = slice_in_dim_1(self.window_length)( + stored_values, curr_win_index + ) + + # Write current keys,values to stored keys, values. + # stored_keys = jax.lax.dynamic_update_slice_in_dim( + # stored_keys, keys, curr_index, axis=1) + stored_keys = update_slice_in_dim_1(stored_keys, keys, curr_index) + # stored_values = jax.lax.dynamic_update_slice_in_dim( + # stored_values, values, curr_index, axis=1) + stored_values = update_slice_in_dim_1(stored_values, values, curr_index) + curr_index = curr_index + 1 + + # Pack a new decoder_state object. + out_decoder_state["keys"] = stored_keys + out_decoder_state["values"] = stored_values + out_decoder_state["current_index"] = curr_index + out_decoder_state["relative_position_bias"] = decoder_state[ + "relative_position_bias" + ] + out_decoder_state["recurrent_kvq"] = decoder_state["recurrent_kvq"] + + return (DecoderState(out_decoder_state), out_keys, out_values) + + def __call__( + self, + xs: Array, + start_of_sequence: Array, + *, + importance: Optional[Array] = None, + cross_attention_kv: Optional[Tuple[Array, Array]] = None, + window_state: Optional[WindowState] = None, + decoder_state: Optional[DecoderState] = None, + ): + """Computes attention over a sequence of inputs. 
+
+    Args:
+      xs: input sequence of shape (batch_size, sequence_length, num_hidden)
+      start_of_sequence: An input array of shape (batch_size)
+        --- The following must be passed by keyword only. ---
+      importance: Array of shape (batch_size, sequence_length). An importance
+        bias for attention.
+      cross_attention_kv: Keys and values from encoder for cross-attention.
+      window_state: State object which contains context from the prior window
+        when using a transformer-XL or sliding window. Initially created with
+        load_window_state().
+      decoder_state: State object for autoregressive decoding, initially created
+        with init_decoder_state().
+
+    Returns:
+      (ys: outputs of shape (batch_size, sequence_length, num_hidden),
+       importance_score: importance score for the next layer,
+       next_window_state: state to pass to the next window,
+       next_decoder_state: next decoder state for autoregressive decoding,
+       viz_dict: dictionary of visualizations
+      )
+    """
+
+    xs = jnp.asarray(xs, dtype=self.dtype)
+    logging.info("tlayer: recurrent = %r", self.recurrent_attention)
+    logging.info("tlayer: compute_importance = %r", self.compute_importance)
+
+    is_training = self.mode == "train"
+
+    # Compute keys, values and queries.
+    # ---------------------------------
+    logging.info("tlayer: compute keys,values,queries.")
+    (keys, values, queries, queries2) = self.tbase.kvq(xs)
+    attention_scale_factors = self.tbase.attention_scale_factors()
+    (_, sequence_length, num_heads, _) = queries.shape  # (b, k, h, d)
+
+    # Get biases and masks that are shared across windows.
+    # ----------------------------------------------------
+    if decoder_state is not None:
+      logging.info("tlayer: using autoregressive decoder.")
+      # When decoding, prior keys,values are loaded from the decoder state.
+      # Other values are precomputed, and loaded from the decoder state.
+      # The decoder state will be updated with the current token.
+ assert window_state is None + + prev_kvi = None + recurrent_state = None # Use precomputed recurrent_kvq. + cross_attention_kv = None + rel_position_bias = decoder_state["relative_position_bias"] + causal_mask = None + dropout_multiplier = None + + # Reuse cached recurrent keys,values for each token. + cached_recurrent_kvq = decoder_state["recurrent_kvq"] + if cached_recurrent_kvq is not None: + assert cross_attention_kv is None + cross_attention_kv = (cached_recurrent_kvq[0], cached_recurrent_kvq[1]) + del cached_recurrent_kvq + + # Get a full window of keys,values and update decoder state. + (decoder_state, keys, values) = self._next_decoder_state( + decoder_state, keys, values + ) + + # Each query attends to window_length prior keys. + assert keys.shape[1] == self.window_length + kq_relative_offset = self.window_length + + if not self.use_long_xl_architecture: + kqpos = position.relative_positions( + 1, self.window_length, offset=0 + ) # 2D mask + current_idx = decoder_state["current_index"] + + # add (batch, heads) dims for kqpos + kqpos = jnp.expand_dims(kqpos, axis=(0, 1)) + kqpos = jnp.tile(kqpos, (1, self.num_heads, 1, 1)) + + # add (_, heads, _) dim for current_idx + current_idx = jnp.expand_dims(current_idx, axis=(1, 2, 3)) + + causal_mask = kqpos > self.window_length * 2 - current_idx + else: + logging.info("tlayer: windowed attention.") + # When training, attention is done using windows or chunks, and prior + # context (e.g. keys,values from the previous window) is stored in the + # window_state object. + (prev_kvi, recurrent_state) = ( + window_state # pytype: disable=attribute-error + ) + + # Get the size of the sliding window for pos bias, dropout, & causal mask. + (num_queries, num_keys) = attention.sliding_attention_window_shape( + (keys, values, importance), + prev_kvi, + queries, + window_length=self.window_length, + ) + kq_relative_offset = num_keys - num_queries + + # Get the relative position bias. 
+ # The bias doesn't depend on the query content, and so can be precomputed. + if self.relative_positions is not None: + rel_position_bias = self.relative_positions( + num_queries, num_keys, bidirectional=False + ) + else: + rel_position_bias = None + + # Get causal mask. + if self.use_causal_mask: + causal_mask = position.causal_mask( + num_queries, num_keys, window_length=self.window_length + ) + else: + causal_mask = None + + # Apply dropout to the attention matrix. + # The mask will be broadcast across batches and windows. + if self.attn_dropout_rate > 0.0 and is_training: + dropout_rng = self.make_rng("dropout") + attn_shape = (self.num_heads, num_queries, num_keys) + dropout_multiplier = nn_components.dropout_multiplier_mask( + dropout_rng, self.attn_dropout_rate, attn_shape, self.dtype + ) + else: + dropout_multiplier = None + + # Load and store values into external memory, if memory is not None. + # ------------------------------------------------------------------ + (mode, _, update_memory) = self._get_cache_name_from_mode(self.mode) + external_kv = self._query_external_memory( + keys, + values, + queries, + start_of_sequence=start_of_sequence, + mode=mode, + update_memory=decoder_state is None and update_memory, + ) + + if ( + self.memory is not None + and self.memory_combine_with_local == "TRAINABLE_WEIGHTED_MEAN" + ): + external_memory_bias = jnp.asarray(self.memory_bias, dtype=self.dtype) + external_memory_bias = jnp.reshape( + external_memory_bias, (1, 1, num_heads, 1) + ) + external_memory_bias = jax.nn.sigmoid(external_memory_bias) + else: + external_memory_bias = None + + # Compute the number of windows. + # ------------------------------ + if sequence_length < self.window_length: + num_windows = 1 # Happens with autoregressive decoding. 
+    elif sequence_length == self.window_length:
+      num_windows = 1
+      if self.use_long_xl_architecture:
+        assert prev_kvi is not None
+    else:
+      if not self.use_long_xl_architecture:
+        raise ValueError("Can only use sliding window with Transformer XL.")
+      num_windows = sequence_length // self.window_length
+      if (num_windows * self.window_length) != sequence_length:
+        raise ValueError(
+            f"Sequence length {sequence_length} must be a "
+            + f"multiple of window length {self.window_length}"
+        )
+    logging.info("tlayer: num_windows = %d.", num_windows)
+
+    # Define the function to do attention within a single window.
+    # ---------------------------------------------------------
+    def single_window_attention(
+        carry: tuple[Array, Array], inputs_w: tuple[Array, Array]
+    ) -> tuple[tuple[Array, Array], tuple[Array, Array]]:
+      # This function uses the following variables from the outer scope.
+      # They are listed here for clarity.
+      nonlocal rel_position_bias
+      nonlocal causal_mask
+      nonlocal kq_relative_offset
+      nonlocal dropout_multiplier
+      nonlocal attention_scale_factors
+      nonlocal external_memory_bias
+      nonlocal cross_attention_kv  # externally supplied.
+
+      # keys,values,queries over the whole sequence will be split into chunks.
+      # xs_w, kvqi_w, etc. are the chunk for the current window.
+      (prev_kvi_w, rec_state) = carry  # carried from one window to the next.
+      (kvqi_w, external_kv_w) = inputs_w  # inputs to the current window.
+      # (keys_curr_w, values_curr_w, _, _, importance_curr_w) = kvqi_w
+
+      # Concatenate keys,values from the previous window with the current
+      # window to implement sliding window attention.
+      (kvqi_w, next_kvi_w) = attention.concat_kvqi(kvqi_w, prev_kvi_w)
+      (keys_w, values_w, queries_w, queries2_w, importance_w) = kvqi_w
+
+      # Perform recurrent attention within the current window to get the next
+      # recurrent state, and set up cross attention.
+ if rec_state is not None: + logging.info("tlayer: recurrent attention.") + + # NOTE -- recurrent states and input tokens are handled separately, + # because they have separate learned positional embeddings. Due to + # the way TransformerBase does cross-attention, this means that we use + # separate key,value layers for rec_state and tokens_w. + + # Keys, values, queries from recurrent state. + logging.info("tlayer: recurrent kvq.") + rec_kvq = self.recurrent_tbase.kvq(rec_state) + r_scale_factors = self.recurrent_tbase.attention_scale_factors() + (r_keys, r_values, r_queries, r_queries2) = rec_kvq + + # Joint attention over both recurrent states and input tokens. + logging.info("tlayer: recurrent self-attention.") + r_attn_ys = attention.simple_attention( + r_keys, + r_values, + r_queries, + None, + scale_factor=r_scale_factors[0], + dtype=self.dtype, + ) + + logging.info("tlayer: recurrent cross-attention.") + r_cross_attn_ys = attention.simple_attention( + keys_w, + values_w, + r_queries2, + importance_w, + scale_factor=r_scale_factors[1], + dtype=self.dtype, + ) + + # Recurrent post-attention FFN. + logging.info("tlayer: recurrent ffn.") + next_rec_state = self.recurrent_tbase.post_attn_ffn( + rec_state, r_attn_ys, r_cross_attn_ys + ) + + # Get keys and values for cross-attention from recurrent state. + assert cross_attention_kv is None + local_cross_attention_kv = (r_keys, r_values) + else: + # Get keys and values for cross-attention from external argument. + next_rec_state = None + local_cross_attention_kv = cross_attention_kv + + # If using RoPE, keys and queries are rotated before self-attention. + if self.relative_position_type == "rotary": + logging.info( + "Using rotary position encodings (RoPE), offset = %d", + kq_relative_offset, + ) + (keys_w, queries_w) = position.rotate_kq( + keys_w, queries_w, max_wavelength=10_000, offset=kq_relative_offset + ) + + # Self-attention over input tokens. 
+ logging.info("tlayer: self-attention.") + attn_ys_w = attention.simple_attention( + keys_w, + values_w, + queries_w, + importance_w, + relative_position_bias=rel_position_bias, + scale_factor=attention_scale_factors[0], + causal_mask=causal_mask, + dropout_multiplier=dropout_multiplier, + dtype=self.dtype, + ) + + # Attention over external memory. + if external_kv_w is not None: + (external_keys_w, external_values_w) = external_kv_w + y_ext = attention.external_attention( + external_keys_w, + external_values_w, + queries_w, + scale_factor=attention_scale_factors[0], + ) + if external_memory_bias is not None: + ebias = external_memory_bias + attn_ys_w = (attn_ys_w * (1 - ebias)) + (y_ext * ebias) + elif self.memory_combine_with_local == "ADD": + attn_ys_w += y_ext + elif self.memory_combine_with_local == "STOP_FORWARD": + attn_ys_w = y_ext + (attn_ys_w - jax.lax.stop_gradient(attn_ys_w)) + else: + raise ValueError( + f"Unexpected setting: {self.memory_combine_with_local = }" + ) + + # Cross attention from input tokens to encoder or recurrent state. + if local_cross_attention_kv is not None: + logging.info("tlayer: cross-attention.") + (c_keys, c_values) = local_cross_attention_kv + + # Cross-attention using queries2. + cross_attn_ys_w = attention.simple_attention( + c_keys, + c_values, + queries2_w, + None, + scale_factor=attention_scale_factors[1], + dtype=self.dtype, + ) + else: + cross_attn_ys_w = None + + # End function single_window_attention(...) + return ((next_kvi_w, next_rec_state), (attn_ys_w, cross_attn_ys_w)) + + # Initialize recurrent_tbase before calling jax.lax.scan. + # Otherwise flax will throw a tantrum. + if ( + self.recurrent_attention + and 0 <= self.max_unrolled_windows + and self.max_unrolled_windows < num_windows + ): + logging.info("tlayer: force initialization of recurrent_tbase.") + self.recurrent_tbase.force_init(recurrent_state) + + # Perform sliding window attention over all keys,values,queries. 
+ # -------------------------------------------------------------- + initial_carry = (prev_kvi, recurrent_state) # window state. + kvqi = (keys, values, queries, queries2, importance) + attn_inputs = (kvqi, external_kv) + (next_carry, attn_outputs) = attention.split_and_scan( + single_window_attention, + initial_carry, + attn_inputs, + sections=num_windows, + axis=1, + max_unrolled_windows=self.max_unrolled_windows, + ) + (attn_ys, cross_attn_ys) = attn_outputs + + logging.info("tlayer: End windows.") + + # Post-attention MLP, resnet, and FFN. + # ------------------------------------ + logging.info("tlayer: final FFN.") + ys = self.tbase.post_attn_ffn(xs, attn_ys, cross_attn_ys) + + # Compute importance scores for each token if requested. + if self.compute_importance: + (batch_size, sequence_length, _) = ys.shape + importance_score = self.importance_layer(ys) + importance_score = importance_score.reshape((batch_size, sequence_length)) + else: + importance_score = None + + next_window_state = next_carry if window_state is not None else None + viz_dict = {} # Visualizations, not currently enabled. + return (ys, importance_score, next_window_state, decoder_state, viz_dict) + + def init_decoder_state_vanilla( + self, sequence_length: int, start_of_sequence: Array + ) -> DecoderState: + """Initialize decoder state for autoregressive generation. + + Args: + sequence_length: The maximum length of the sequence to generate. + start_of_sequence: Array of boolean of shape (batch_size,) True if + starting a new sequence (with no prefix). + + Returns: + A state object that can be passed to __call__. + """ + + if not self.use_causal_mask: + raise ValueError("Generator must have been trained with a causal mask.") + + # Get relative position bias. 
+ rel_position_bias = self.relative_positions( + 1, self.window_length, offset=self.window_length, bidirectional=False + ) + rel_position_bias = jnp.tile(rel_position_bias, (self.batch_size, 1, 1, 1)) + + # Initialize autoregressive storage for (key, value) pairs. + # Include space for a prefix of window_length tokens. + num_keys = sequence_length + self.window_length + stored_shape = (self.batch_size, num_keys, self.num_heads, self.head_size) + stored_keys = jnp.zeros(stored_shape, dtype=self.dtype) + stored_values = jnp.zeros(stored_shape, dtype=self.dtype) + + recurrent_kvq = None + current_index = jnp.array([self.window_length] * self.batch_size) + + decoder_state_dict = { + "keys": stored_keys, + "values": stored_values, + "current_index": current_index, + "relative_position_bias": rel_position_bias, + "recurrent_kvq": recurrent_kvq, + } + return DecoderState(decoder_state_dict)
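The per-example KV-cache bookkeeping in the patch above (`init_decoder_state_vanilla` seeding `current_index` to `window_length`, then `_next_decoder_state` slicing a window of prior keys before writing the current token) can be sketched in plain Python. This is an illustrative sketch, not part of the patch: the function and variable names here are hypothetical, and the vmapped JAX dynamic slices (`slice_in_dim_1`, `update_slice_in_dim_1`) are replaced by per-example loops over Python lists.

```python
def next_decoder_state(stored, indices, new_items, window_length):
    """One autoregressive KV-cache step (illustrative sketch).

    stored:     one list per batch example, pre-sized to
                sequence_length + window_length entries (cf. num_keys
                in init_decoder_state_vanilla).
    indices:    per-example write positions (each starts at window_length).
    new_items:  the single new key/value per example for this step.

    Returns (stored, indices, windows), where windows[i] holds the
    window_length entries preceding example i's write position --
    the local context the current token attends to.
    """
    windows = []
    for i, (buf, idx, item) in enumerate(zip(stored, indices, new_items)):
        # Slice the window of prior context: positions [idx - window_length, idx).
        windows.append(buf[idx - window_length:idx])
        buf[idx] = item       # write the current token at this example's index
        indices[i] = idx + 1  # advance only this example's position
    return stored, indices, windows
```

Because every example carries its own index, sequences with different prefix lengths in the same batch stay aligned, which is what the vmapped `update_slice_in_dim_1`/`slice_in_dim_1` helpers accomplish on batched arrays.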