Thanks a lot for the detailed feedback, even including a nice proposal how to update the API!

### Classes vs. Instances

The biggest difference in your design, as compared to the existing `clu.metrics` API, is that metrics are constructed as instances, whereas `clu.metrics` works with classes. The advantage of using classes is that we declare collections of metrics declaratively, which feels very Flaxy:

```python
@flax.struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")
```

### Parametrized Metrics

There are different ways to parametrize an existing `Metric`.

### Efficiency

Thanks for highlighting the issue with the jitted `eval_step`. A minimal way of extending the existing API would be to add the following class method:

```python
class Metric:
  # [...]

  @classmethod
  def empty(cls) -> "Metric":
    """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op)."""
    raise NotImplementedError("Must override empty()")
```
Every sub-classed metric that wants to make use of the added API would then need to implement this new class method. For example, the `Average` metric:

```python
@flax.struct.dataclass
class Average(Metric):
  # [...]

  @classmethod
  def empty(cls) -> Metric:
    return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
```
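To illustrate the intended semantics, a quick sketch (mine, not part of the proposal; it assumes `Average.from_model_output()` behaves as in `clu.metrics`, and the values are arbitrary):

```python
m = Average.from_model_output(jnp.array([1., 2., 3.]))
# Merging with `empty()` leaves the result unchanged, which is what makes it a
# safe initial value for an accumulation loop.
assert m.merge(Average.empty()).compute() == m.compute()  # == 2.0
```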
Finally, the `Collection` class:

```python
@flax.struct.dataclass
class Collection:
  # [...]

  @classmethod
  def empty(cls) -> "Collection":
    return cls(
        _reduction_counter=_ReductionCounter(jnp.array(1)),
        **{
            metric_name: metric.empty()
            for metric_name, metric in cls.__annotations__.items()
        })
```

So finally we can move the merging of the metrics into the jitted `eval_step`:

```python
from clu import metrics
import flax
import jax

@flax.struct.dataclass  # required for jax.tree_*
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")

def eval_step(ms, model, variables, inputs, labels):
  loss, logits = get_loss_and_logits(model, variables, inputs, labels)
  return ms.merge(Metrics.gather_from_model_output(
      loss=loss, logits=logits, labels=labels))

# `model` (argument 1) is static; the metrics stay replicated across devices.
p_eval_step = jax.pmap(
    eval_step, axis_name="batch", static_broadcasted_argnums=1)

def evaluate(model, p_variables, test_ds):
  ms = flax.jax_utils.replicate(Metrics.empty())
  for inputs, labels in test_ds:  # batches assumed to have a leading device axis
    ms = p_eval_step(ms, model, p_variables, inputs, labels)
  return flax.jax_utils.unreplicate(ms).compute()
```

### Summary

I think this small API extension would address your concern 2 (concern 1 is already covered in the existing API). I would prefer to keep as much as possible from the existing API, because we already have a lot of users using that API and updating them to a new API would be very costly. The proposed API change is purely additive, so users who do the metric summation outside their jitted `eval_step` can keep doing so.
---
Hey @andsteing, thanks for the detailed response! I understand that drastically changing the API might be challenging or even impossible given it could break Google-internal code. I do have a couple of additional points I will mention, but in the end I think your proposal works.

### Typing

When I was playing with `clu` I found that this pattern:

```python
@flax.struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output("loss")
  loss_std: metrics.Std.from_output("loss")
```

is actually disliked / not encouraged by some Python linters. I don't know if there is a PEP for this, but Pylance warns against this pattern, since the fields are annotated with the result of a call expression rather than a type. I have not checked with other type checkers.

### Empty

I like this solution. The nice thing about having access to an instance (thanks to `empty()`) is that `gather_from_model_output` can be called on the instance itself:

```python
def eval_step(ms: Collection, model, variables, inputs, labels):
  loss, logits = get_loss_and_logits(model, variables, inputs, labels)
  updates = ms.gather_from_model_output(loss=loss, logits=logits, labels=labels)
  return ms.merge(updates)
```

I like it.

### Parametrization via function-local classes

Edit: no longer relevant, as I saw the strategy in a colab you shared; keeping the original response below.

Tried to create a simple (possibly flawed) `BinaryAccuracy` with a `threshold` parameter in `clu`, this is what I got:

```python
@flax.struct.dataclass
class BinaryAccuracy(metrics.Average):

  @classmethod
  def from_model_output(cls, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs):
    values = ((logits > 0.5) == labels).astype(jnp.float32)
    return super().from_model_output(values, **kwargs)

  @staticmethod
  def with_params(threshold: float = 0.5):

    @flax.struct.dataclass
    class BinaryAccuracyWithParams(metrics.Average):

      @classmethod
      def from_model_output(cls, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs):
        values = ((logits > threshold) == labels).astype(jnp.float32)
        return super().from_model_output(values, **kwargs)

    return BinaryAccuracyWithParams
```

I am not too happy about the approach; maybe it can be cleaned up to avoid code duplication, but it feels a bit more complex than having instances, which by nature are easy to parametrize.

### Some thoughts (opinion)

Feel free to ignore this section, it's just some random thoughts I've had during the process.

#### Asymmetry between Metric and Collection APIs

I am very curious why the `Metric` and `Collection` APIs ended up being asymmetric here.

#### Collection-like API via instances

This is probably not important, but I'll just mention it in case you are interested: I did try to mimic a `Collection`-like API via instances in `flax-tools`:

```python
import numpy as np
from flax_tools.metrics import Metrics, Accuracy, Mean

loss = np.random.uniform(size=(10,))
logits = np.random.uniform(size=(10, 10))
labels = np.random.randint(0, 10, size=(10,))

metrics = Metrics.new(
    [
        Accuracy.new(),
        Mean.new(name="loss").on_args("loss"),
    ]
).reset()

metrics = metrics.update(preds=logits, target=labels, loss=loss)
logs = metrics.compute()  # e.g. {'accuracy': 0.3, 'loss': 0.47997332}
```
---
### Current State

`Metric` from `clu` currently exposes the following API: a `from_model_output()` class method for creating a metric from model outputs, plus `merge()` and `compute()` methods. Documentation currently suggests they are used like this:
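Roughly along these lines (a sketch of mine rather than the exact documented snippet; it assumes `from clu import metrics`, and the `logits`/`labels` arrays are placeholders):

```python
metric = metrics.Accuracy.from_model_output(logits=logits, labels=labels)
metric = metric.merge(
    metrics.Accuracy.from_model_output(logits=more_logits, labels=more_labels))
print(metric.compute())
```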
However, if you try to implement this in terms of a realistic jitted `eval_step`, the pattern becomes more complex:
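Something along these lines (my reconstruction, not the snippet from the original post; `model`, `variables` and `eval_ds` are placeholders, and `metrics` is `clu.metrics`):

```python
@jax.jit
def eval_step(metric, variables, batch):
  logits = model.apply(variables, batch["image"])
  update = metrics.Accuracy.from_model_output(logits=logits, labels=batch["label"])
  return update if metric is None else metric.merge(update)

metric = None
for batch in eval_ds:
  # `metric` is `None` on the first step and an `Accuracy` instance afterwards,
  # so `eval_step` gets traced and compiled twice.
  metric = eval_step(metric, variables, batch)
print(metric.compute())
```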
This has the following downsides:

1. `eval_step` always has to recompile twice, as `metric` will change from `None` in the first step to a `Metric` instance from then on.
2. Metrics cannot be (easily) parametrized, since `MetricClass.from_model_output` takes no parameters other than the actual values. If we take a look at a more complex metric such as `tf.keras.metrics.BinaryIoU`, we see it has a couple of parameters such as `threshold`. I might be missing something, but I don't see an easy way of implementing this in `clu`.
### Proposal

My suggestion to solve both is the following API:
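Roughly along these lines (my sketch, reconstructed from the differences listed below rather than copied from the post):

```python
class Metric:

  def reset(self) -> "Metric":
    """Returns the metric in a neutral/zero state."""
    ...

  def update(self, **kwargs) -> "Metric":
    """Returns a new metric with the state updated from the incoming values."""
    ...

  def batch_updates(self, **kwargs) -> "Metric":
    """`reset()` + `update()`: the contribution of a single batch."""
    ...

  def merge(self, other: "Metric") -> "Metric":  # unchanged from the current API
    ...

  def compute(self):  # unchanged from the current API
    ...
```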
This has the following differences:

- A `reset` method, which should leave the metric in a neutral/zero state.
- `update` knows how to update the current state based on incoming values.
- `from_model_output` is replaced with `batch_updates`, which does `reset` + `update`.

### Example
A simple implementation of `Accuracy` could be:
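For instance (a sketch of mine under the proposed API above, not the code from the original post; the field names are assumptions):

```python
import flax
import jax.numpy as jnp

@flax.struct.dataclass
class Accuracy(Metric):
  total: jnp.ndarray
  count: jnp.ndarray

  def reset(self) -> "Accuracy":
    return self.replace(
        total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))

  def update(self, *, logits: jnp.ndarray, labels: jnp.ndarray, **_) -> "Accuracy":
    correct = (logits.argmax(axis=-1) == labels).astype(jnp.float32)
    return self.replace(total=self.total + correct.sum(),
                        count=self.count + correct.size)

  def batch_updates(self, **kwargs) -> "Accuracy":
    return self.reset().update(**kwargs)

  def merge(self, other: "Accuracy") -> "Accuracy":
    return self.replace(total=self.total + other.total,
                        count=self.count + other.count)

  def compute(self) -> jnp.ndarray:
    return self.total / self.count
```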
Now the previous example can be slightly simplified; for a non-distributed setup you can even just use `.update` directly, since you don't need to synchronize metric state (`batch_updates`) between devices. Both variants are sketched below:
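Again my sketches, reusing the `Accuracy` from above (`model`, `variables` and `eval_ds` are placeholders):

```python
# Jitted variant: `metric` is always an `Accuracy` instance, so `eval_step`
# only compiles once. In a pmapped setup, the `batch_updates` result is what
# would be gathered across devices before the `merge`.
@jax.jit
def eval_step(metric, variables, batch):
  logits = model.apply(variables, batch["image"])
  return metric.merge(metric.batch_updates(logits=logits, labels=batch["label"]))

metric = Accuracy(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
for batch in eval_ds:
  metric = eval_step(metric, variables, batch)
print(metric.compute())
```

```python
# Non-distributed variant: no state to synchronize, so `update` is enough.
metric = Accuracy(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
for batch in eval_ds:
  logits = model.apply(variables, batch["image"])
  metric = metric.update(logits=logits, labels=batch["label"])
print(metric.compute())
```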
### Parametrized Metrics
Now the obvious benefit of being able to instantiate a `Metric` from outside is that you can define parametrized metrics, e.g. you could implement an `Accuracy` metric with a `topk` parameter and use it like this:
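A possible sketch (mine, not the original post's code; it follows the conventions of the `Accuracy` sketch above, and `model`, `variables` and `eval_ds` are again placeholders):

```python
@flax.struct.dataclass
class TopKAccuracy(Metric):
  topk: int = flax.struct.field(pytree_node=False)  # static parameter
  total: jnp.ndarray
  count: jnp.ndarray

  def reset(self) -> "TopKAccuracy":
    return self.replace(
        total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))

  def update(self, *, logits: jnp.ndarray, labels: jnp.ndarray, **_) -> "TopKAccuracy":
    topk_preds = jnp.argsort(logits, axis=-1)[..., -self.topk:]
    correct = (topk_preds == labels[..., None]).any(axis=-1).astype(jnp.float32)
    return self.replace(total=self.total + correct.sum(),
                        count=self.count + correct.size)

  def compute(self) -> jnp.ndarray:
    return self.total / self.count

# The parameter is simply passed at construction time (a `new()`-style factory
# could hide the zero-initialization of the state fields).
metric = TopKAccuracy(topk=5,
                      total=jnp.array(0, jnp.float32),
                      count=jnp.array(0, jnp.int32))
for batch in eval_ds:
  metric = metric.update(logits=model.apply(variables, batch["image"]),
                         labels=batch["label"])
print(metric.compute())
```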
### Reference Implementation

I've been playing around with this definition of `Metric` in a non-published repo called `flax-tools`; you can check the definition of `Metric` and the implementation of a couple of non-trivial metrics ported from Treex in `flax_tools/metrics`.

cc: @jheek @marcvanzee