diff --git a/src/Loggers/LogHParams.jl b/src/Loggers/LogHParams.jl index c992225..3fef01e 100644 --- a/src/Loggers/LogHParams.jl +++ b/src/Loggers/LogHParams.jl @@ -11,7 +11,7 @@ end DiscreteDomainElem = Union{String, Float64, Bool} -hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING +hparams_datatype_sym(::Type{<:AbstractString}) = :DATA_TYPE_STRING hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64 hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL @@ -19,16 +19,16 @@ function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainEl tensorboard.hparams._DataType[hparams_datatype_sym(T)] end -ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x) -ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x) -ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x) -function ProtoBuf.google.protobuf.Value(x) +# custom constructors for ProtoBuf.google.protobuf.Value +function _Protobuf_Value(x) @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null." Value(null_value=Int32(0)) end +_Protobuf_Value(x::Bool) = Value(bool_value=x) +_Protobuf_Value(x::Number) = Value(number_value=x) +_Protobuf_Value(x::AbstractString) = Value(string_value=x) - -function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem +function _Protobuf_ListValue(domain::DiscreteDomain{T})where T <: DiscreteDomainElem ProtoBuf.google.protobuf.ListValue( values = ProtoBuf.google.protobuf.Value.(domain.values) ) @@ -58,7 +58,7 @@ function HParamInfo(hparam::HParam) else @assert isa(domain, DiscreteDomain) (_type = hparams_datatype(domain), - domain_discrete = ProtoBuf.google.protobuf.ListValue(domain)) + domain_discrete = _Protobuf_ListValue(domain)) end HParamInfo(;name = hparam.name, description = hparam.description, @@ -177,7 +177,7 @@ function hparams_summary(hparams_dict::Dict{HParam, Any}, group_name = group_name, start_time_secs = start_time_secs, hparams = Dict( - hparam.name => ProtoBuf.google.protobuf.Value(val) + hparam.name => _ProtoBuf_Value(val) for (hparam, val) ∈ hparams_dict ) )