Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Jan 15, 2023
1 parent 02933a9 commit 6f3bd0a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 32 deletions.
36 changes: 22 additions & 14 deletions sparsetensorviz/notebooks/Example_Rank4-bundled.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,9 @@
"import IPython\n",
"from sphinxcontrib.svgbob._svgbob import to_svg\n",
"\n",
"IPython.display.SVG(to_svg(\n",
"r\"\"\"\n",
"IPython.display.SVG(\n",
" to_svg(\n",
" r\"\"\"\n",
" .----.\n",
" | |\n",
" | v\n",
Expand All @@ -164,8 +165,9 @@
" .--+-------.\n",
" |Sparse[n3]|\n",
" `----------'\n",
"\"\"\"\n",
"))"
" \"\"\"\n",
" )\n",
")"
]
},
{
Expand Down Expand Up @@ -633,20 +635,21 @@
"# Helper function\n",
"seen_taco = {}\n",
"\n",
"\n",
"def to_tim(st):\n",
" rv = []\n",
" for gp in st.bundled_groups:\n",
" if 'expanded=True' in gp:\n",
" if \"expanded=True\" in gp:\n",
" return\n",
" num = int(gp[:-1].split('(', 1)[1])\n",
" num = int(gp[:-1].split(\"(\", 1)[1])\n",
" if gp.startswith(\"Coord\"):\n",
" rv.extend([\"index\"]*num)\n",
" rv.extend([\"index\"] * num)\n",
" elif gp.startswith(\"Full\"):\n",
" rv.extend([\"full\"]*num)\n",
" rv.extend([\"full\"] * num)\n",
" elif gp.startswith(\"Sparse\"):\n",
" rv.extend([\"sparse\"]*num)\n",
" rv.extend([\"sparse\"] * num)\n",
" elif gp.startswith(\"Hyper\"):\n",
" rv.extend([\"index\"]*(num-1))\n",
" rv.extend([\"index\"] * (num - 1))\n",
" rv.append(\"hyper\")\n",
" else:\n",
" raise ValueError(f\"Bad group: {gp}\")\n",
Expand All @@ -665,7 +668,10 @@
" tim_structure = to_tim(st)\n",
" if align:\n",
" if tim_structure is not None:\n",
" widths = [max(len(str(x)), len(y), len(z)) for x, y, z in zip(st.structure, st.taco_structure, tim_structure)]\n",
" widths = [\n",
" max(len(str(x)), len(y), len(z))\n",
" for x, y, z in zip(st.structure, st.taco_structure, tim_structure)\n",
" ]\n",
" else:\n",
" widths = [max(len(str(x)), len(y)) for x, y in zip(st.structure, st.taco_structure)]\n",
" structure = \", \".join(\n",
Expand All @@ -674,9 +680,11 @@
" if tim_structure is None:\n",
" tim_structure = \"N/A\"\n",
" else:\n",
" tim_structure = \"[\" + \", \".join(\n",
" sparse.rjust(width) for sparse, width in zip(tim_structure, widths)\n",
" ) + \"]\"\n",
" tim_structure = (\n",
" \"[\"\n",
" + \", \".join(sparse.rjust(width) for sparse, width in zip(tim_structure, widths))\n",
" + \"]\"\n",
" )\n",
" taco_structure = \", \".join(\n",
" sparse.rjust(width) for sparse, width in zip(st.taco_structure, widths)\n",
" )\n",
Expand Down
9 changes: 6 additions & 3 deletions sparsetensorviz/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ inline-quotes = "
exclude =
versioneer.py,
ignore =
E203, # whitespace before ':'
E231, # Multiple spaces around ","
W503, # line break before binary operator
E203,
E231,
W503,
B020
# E203: whitespace before ':'
# E231: Multiple spaces around ","
# W503: line break before binary operator
per-file-ignores =
__init__.py:F401
Expand Down
4 changes: 3 additions & 1 deletion sparsetensorviz/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import versioneer

install_requires = open("requirements.txt").read().strip().split("\n")
with open("requirements.txt") as f:
install_requires = f.read().strip().split("\n")

extras_require = {
"test": ["pytest"],
"viz": ["sphinxcontrib-svgbob"],
Expand Down
30 changes: 18 additions & 12 deletions sparsetensorviz/sparsetensorviz/_bundle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import re

from ._core import SparseTensor
Expand All @@ -24,7 +23,8 @@ def match(cls, s):


class AllFull(MatcherBase):
""" F[-F] """
"""F[-F]"""

pattern = re.compile("^F(?P<F>(-F)*)$")

def __new__(cls, ma, s):
Expand All @@ -33,7 +33,8 @@ def __new__(cls, ma, s):


class AllCoord(MatcherBase):
""" S[-S] """
"""S[-S]"""

pattern = re.compile("^(?P<S>(S-)*)S$")

def __new__(cls, ma, s):
Expand All @@ -42,7 +43,8 @@ def __new__(cls, ma, s):


class InitSparse(MatcherBase):
""" C-[C-] """
"""C-[C-]"""

pattern = re.compile("^(?P<C>(C-)+)")

def __new__(cls, ma, s):
Expand All @@ -51,7 +53,8 @@ def __new__(cls, ma, s):


class CoordSparse(MatcherBase):
""" [S-]DC-C-[C-] """
"""[S-]DC-C-[C-]"""

pattern = re.compile("^(?P<S>(S-)*)DC-(?P<C>(C-)+)")

def __new__(cls, ma, s):
Expand All @@ -61,7 +64,8 @@ def __new__(cls, ma, s):


class CoordSparseExpanded(MatcherBase):
""" [S-]S-C-[C-] """
"""[S-]S-C-[C-]"""

pattern = re.compile("^(?P<S>(S-)+)(?P<C>(C-)+)")

def __new__(cls, ma, s):
Expand All @@ -71,7 +75,8 @@ def __new__(cls, ma, s):


class HyperSparse(MatcherBase):
""" [S-]DC- """
"""[S-]DC-"""

pattern = re.compile("^(?P<S>(S-)*)DC-")

def __new__(cls, ma, s):
Expand All @@ -80,7 +85,8 @@ def __new__(cls, ma, s):


class CoordFull(MatcherBase):
""" [S-]S-F[-F] """
"""[S-]S-F[-F]"""

pattern = re.compile("^(?P<S>(S-)+)F(?P<F>(-F)*)$")

def __new__(cls, ma, s):
Expand All @@ -99,12 +105,12 @@ def to_bundled_groups(s):
if "-" not in s:
s = "-".join(s)
orig_s = s
if (ma := AllFull.match(s)): # All F
if ma := AllFull.match(s): # All F
return AllFull(ma, s)
if (ma := AllCoord.match(s)): # All S
if ma := AllCoord.match(s): # All S
return AllCoord(ma, s)
rv = []
if (ma := InitSparse.match(s)): # Begins with C
if ma := InitSparse.match(s): # Begins with C
rv.extend(InitSparse(ma, s))
s = trim(ma, s)
matchers = [
Expand All @@ -116,7 +122,7 @@ def to_bundled_groups(s):
]
while s:
for matcher in matchers:
if (ma := matcher.match(s)):
if ma := matcher.match(s):
rv.extend(matcher(ma, s))
s = trim(ma, s)
break
Expand Down
4 changes: 2 additions & 2 deletions sparsetensorviz/sparsetensorviz/_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,12 @@ def to_text(self, *, squared=False, compact=None, as_taco=False, as_groups=False
if as_taco:
combined = [row[xoffsets[1] :] for row in combined]
elif as_groups:
nums = [int(gp[:-1].split('(', 1)[1].split(',', 1)[0]) for gp in self.bundled_groups]
nums = [int(gp[:-1].split("(", 1)[1].split(",", 1)[0]) for gp in self.bundled_groups]
trim_ranges = []
i = 0
for num in nums:
for j in range(i, i + num - 1):
trim_ranges.append((xoffsets[2*j+1]-1, xoffsets[2*j+2]+1))
trim_ranges.append((xoffsets[2 * j + 1] - 1, xoffsets[2 * j + 2] + 1))
i += num

def trim(row, start, stop):
Expand Down

0 comments on commit 6f3bd0a

Please sign in to comment.