diff --git a/src/Loggers/LogHParams.jl b/src/Loggers/LogHParams.jl
new file mode 100644
index 00000000..c992225b
--- /dev/null
+++ b/src/Loggers/LogHParams.jl
@@ -0,0 +1,237 @@
+# Logging support for the TensorBoard "hparams" plugin.
+
+# Plugin name stored in the `SummaryMetadata` of every hparams summary.
+const PLUGIN_NAME = "hparams"
+# Value written to the `version` field of `HParamsPluginData`.
+const PLUGIN_DATA_VERSION = 0
+
+# Summary tags used for the experiment description and for session-start data.
+const EXPERIMENT_TAG = "_hparams_/experiment"
+const SESSION_START_INFO_TAG = "_hparams_/session_start_info"
+
+"""
+    DiscreteDomain(values)
+
+Hyperparameter domain given as an explicit collection of `values`.
+"""
+struct DiscreteDomain{T}
+    values::AbstractVector{T}
+end
+
+# Element types for which a discrete domain can be converted to protobuf.
+const DiscreteDomainElem = Union{String, Float64, Bool}
+
+hparams_datatype_sym(::Type{String}) = :DATA_TYPE_STRING
+hparams_datatype_sym(::Type{Float64}) = :DATA_TYPE_FLOAT64
+hparams_datatype_sym(::Type{Bool}) = :DATA_TYPE_BOOL
+
+# Look up the hparams `_DataType` entry matching the domain's element type.
+function hparams_datatype(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem
+    return tensorboard.hparams._DataType[hparams_datatype_sym(T)]
+end
+
+# NOTE(review): these methods add Base-typed constructors to a ProtoBuf-owned
+# type (type piracy). They are kept because existing callers may rely on them.
+ProtoBuf.google.protobuf.Value(x::Number) = Value(number_value=x)
+ProtoBuf.google.protobuf.Value(x::Bool) = Value(bool_value=x)
+ProtoBuf.google.protobuf.Value(x::AbstractString) = Value(string_value=x)
+function ProtoBuf.google.protobuf.Value(x)
+    @warn "Cannot create a ProtoBuf.google.protobuf.Value of type $(typeof(x)), defaulting to null."
+    Value(null_value=Int32(0))
+end
+
+function ProtoBuf.google.protobuf.ListValue(domain::DiscreteDomain{T}) where T <: DiscreteDomainElem
+    ProtoBuf.google.protobuf.ListValue(
+        values = ProtoBuf.google.protobuf.Value.(domain.values)
+    )
+end
+
+"""
+    IntervalDomain(min_value, max_value)
+
+Hyperparameter domain described by its `min_value` and `max_value` bounds.
+"""
+struct IntervalDomain
+    min_value::Float64
+    max_value::Float64
+end
+
+Interval(d::IntervalDomain) = Interval(min_value=d.min_value, max_value=d.max_value)
+
+const HParamDomain = Union{IntervalDomain, DiscreteDomain}
+
+"""
+    HParam(name, domain, display_name, description)
+
+Description of a single hyperparameter and the `domain` of its values.
+"""
+struct HParam
+    name::AbstractString
+    domain::HParamDomain
+    display_name::AbstractString
+    description::AbstractString
+end
+
+# Domain-specific keyword arguments for `HParamInfo`, selected by dispatch.
+hparam_domain_kwargs(domain::IntervalDomain) = (; domain_interval = Interval(domain))
+function hparam_domain_kwargs(domain::DiscreteDomain)
+    return (_type = hparams_datatype(domain),
+            domain_discrete = ProtoBuf.google.protobuf.ListValue(domain))
+end
+
+function HParamInfo(hparam::HParam)
+    return HParamInfo(; name = hparam.name,
+                      description = hparam.description,
+                      display_name = hparam.display_name,
+                      hparam_domain_kwargs(hparam.domain)...)
+end
+
+"""
+    Metric(tag, group, display_name, description, dataset_type)
+
+Description of a metric. Throws an `ArgumentError` unless `dataset_type` is one
+of the keys of `tensorboard.hparams.DatasetType`.
+"""
+struct Metric
+    tag::AbstractString
+    group::AbstractString
+    display_name::AbstractString
+    description::AbstractString
+    dataset_type::Symbol
+
+    function Metric(tag::AbstractString,
+                    group::AbstractString,
+                    display_name::AbstractString,
+                    description::AbstractString,
+                    dataset_type::Symbol)
+        valid_dataset_types = keys(tensorboard.hparams.DatasetType)
+        if dataset_type ∉ valid_dataset_types
+            throw(ArgumentError("dataset_type of $(dataset_type) is not one of $(map(string, valid_dataset_types))."))
+        end
+        return new(tag, group, display_name, description, dataset_type)
+    end
+end
+
+function MetricInfo(metric::Metric)
+    MetricInfo(
+        name=MetricName(
+            group=metric.group,
+            tag=metric.tag,
+        ),
+        display_name=metric.display_name,
+        description=metric.description,
+        dataset_type=tensorboard.hparams.DatasetType[metric.dataset_type]
+    )
+end
+
+"""
+    HParamsConfig(hparams, metrics, time_created_secs)
+
+Experiment description: its hyperparameters, its metrics and the value stored
+in the experiment's `time_created_secs` field.
+"""
+struct HParamsConfig
+    hparams::AbstractVector{HParam}
+    metrics::AbstractVector{Metric}
+    time_created_secs::Float64
+end
+
+# Wrap the serialized plugin data in a `SummaryMetadata` tagged with PLUGIN_NAME.
+function SummaryMetadata(hparams_plugin_data::HParamsPluginData)
+    SummaryMetadata(
+        plugin_data = SummaryMetadata_PluginData(
+            plugin_name = PLUGIN_NAME,
+            content = serialize_proto(hparams_plugin_data)
+        )
+    )
+end
+
+# Build a summary value that carries the plugin data in its metadata; the
+# tensor is a float placeholder with an empty shape.
+function Summary_Value(tag, hparams_plugin_data::HParamsPluginData)
+    null_tensor = TensorProto(dtype = _DataType.DT_FLOAT, tensor_shape = TensorShapeProto(dim=[]))
+    Summary_Value(
+        tag = tag,
+        metadata = SummaryMetadata(hparams_plugin_data),
+        tensor = null_tensor
+    )
+end
+
+"""
+    log_hparams_config(logger, hparams_config; step=nothing)
+
+Write the experiment description `hparams_config` to `logger`'s event file.
+"""
+function log_hparams_config(logger::TBLogger,
+                            hparams_config::HParamsConfig;
+                            step=nothing)
+    summ = SummaryCollection(
+        hparams_config_summary(hparams_config)
+    )
+    write_event(logger.file, make_event(logger, summ, step=step))
+end
+
+# Summary value describing the experiment, written under EXPERIMENT_TAG.
+function hparams_config_summary(config::HParamsConfig)
+    Summary_Value(
+        EXPERIMENT_TAG,
+        HParamsPluginData(
+            version = PLUGIN_DATA_VERSION,
+            experiment = Experiment(
+                hparam_infos = HParamInfo.(config.hparams),
+                metric_infos = MetricInfo.(config.metrics),
+                time_created_secs = config.time_created_secs
+            )
+        )
+    )
+end
+
+"""
+    log_hparams(logger, hparams_dict, group_name, trial_id, start_time_secs; step=nothing)
+
+Write the hyperparameter values in `hparams_dict` to `logger`'s event file as a
+session-start summary. When `start_time_secs` is `nothing`, `time()` is used.
+"""
+function log_hparams(logger::TBLogger,
+                     hparams_dict::AbstractDict{HParam},
+                     group_name::AbstractString,
+                     trial_id::AbstractString,
+                     start_time_secs::Union{Float64, Nothing};
+                     step=nothing)
+    summ = SummaryCollection(
+        hparams_summary(hparams_dict,
+                        group_name,
+                        trial_id,
+                        start_time_secs)
+    )
+    write_event(logger.file, make_event(logger, summ, step=step))
+end
+
+"""
+    hparams_summary(hparams_dict, group_name, trial_id, start_time_secs=nothing)
+
+Build the session-start summary value for `hparams_dict`, which maps each
+`HParam` to its value. When `start_time_secs` is `nothing`, `time()` is used.
+"""
+function hparams_summary(hparams_dict::AbstractDict{HParam},
+                         group_name::AbstractString,
+                         trial_id::AbstractString,
+                         start_time_secs::Union{Real, Nothing}=nothing)
+    # NOTE(review): `trial_id` is accepted but not written into the summary.
+    start_time = start_time_secs === nothing ? time() : start_time_secs
+
+    Summary_Value(
+        SESSION_START_INFO_TAG,
+        HParamsPluginData(
+            version = PLUGIN_DATA_VERSION,
+            session_start_info = SessionStartInfo(
+                group_name = group_name,
+                start_time_secs = start_time,
+                hparams = Dict(
+                    hparam.name => ProtoBuf.google.protobuf.Value(val)
+                    for (hparam, val) ∈ hparams_dict
+                )
+            )
+        )
+    )
+end
diff --git a/src/TensorBoardLogger.jl b/src/TensorBoardLogger.jl
index 6b0b520d..30761619 100644
--- a/src/TensorBoardLogger.jl
+++ b/src/TensorBoardLogger.jl
@@ -20,7 +20,8 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,
 
 export TBLogger, reset!, set_step!, increment_step!, set_step_increment!
 export log_histogram, log_value, log_vector, log_text, log_image, log_images,
-       log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
+       log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar,
+       log_hparams, log_hparams_config
 export map_summaries
 
 export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
@@ -30,13 +31,20 @@ export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
 export tb_multiline, tb_margin
 
 # Wrapper types
-export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios
+export TBText, TBVector, TBHistogram, TBImage, TBImages, TBAudio, TBAudios, TBHParams
 
 # Protobuffer definitions for tensorboard
 include("protojl/tensorboard/tensorboard.jl")
 using .tensorboard: Summary_Value, GraphDef, Summary, Event, SessionLog_SessionStatus, SessionLog
 using .tensorboard: TensorShapeProto_Dim, TensorShapeProto, TextPluginData
 using .tensorboard: TensorProto, SummaryMetadata, SummaryMetadata_PluginData, _DataType
+using .tensorboard.hparams: HParamsPluginData, Experiment, SessionStartInfo, SessionEndInfo, HParamInfo, MetricInfo, Interval, MetricName, DatasetType
+import .tensorboard.hparams
+import .tensorboard: SummaryMetadata, Summary, Summary_Value
+import .tensorboard.hparams: HParamInfo, MetricInfo, Interval
+
+using ProtoBuf
+import ProtoBuf.google.protobuf: Value, ListValue
 
 include("PNG.jl")
 using .PNGImage
@@ -58,6 +66,9 @@ include("Loggers/LogEmbeddings.jl")
 # Custom Scalar Plugin
 include("Loggers/LogCustomScalar.jl")
 
+include("Loggers/LogHParams.jl")
+
+
 include("logger_dispatch.jl")
 include("logger_dispatch_overrides.jl")
 
diff --git a/src/logger_dispatch_overrides.jl b/src/logger_dispatch_overrides.jl
index 9396911c..a6b8abbb 100644
--- a/src/logger_dispatch_overrides.jl
+++ b/src/logger_dispatch_overrides.jl
@@ -207,3 +207,27 @@ preprocess(name, val::TBVector{T,N}, data) where {T<:Complex,N} =
     push!(data, name*"/re"=>TBVector(real.(content(val))),
                 name*"/im"=>TBVector(imag.(content(val))))
 summary_impl(name, val::TBVector) = histogram_summary(name, collect(0:length(val.data)), val.data)
+
+########## Hyperparameters ########################
+
+# `name` is unused: the config is always written under EXPERIMENT_TAG.
+summary_impl(name, val::HParamsConfig) = hparams_config_summary(val)
+preprocess(name, val::HParamsConfig, data) = push!(data, name=>val)
+
+"""
+    TBHParams(data, group_name, trial_id, start_time_secs)
+
+Wrapper whose summary is built by `hparams_summary` from the values in `data`.
+"""
+struct TBHParams <: WrapperLogType
+    # TODO: The types in the hparam domain and this dict's values are constrained.
+    # e.g. an hparam with a discrete domain of ["a", "b"] must have string values
+    # Consider ways to enforce this relationship in the type system.
+    data::Dict{HParam, Any}
+    # FIXME: group_name auto generated in the Python implementation (Tensorboard)
+    group_name::AbstractString
+    trial_id::AbstractString
+    start_time_secs::Union{Float64, Nothing}
+end
+content(x::TBHParams) = x.data
+summary_impl(name, val::TBHParams) = hparams_summary(val.data, val.group_name, val.trial_id, val.start_time_secs)
diff --git a/src/protojl/tensorboard/tensorboard.jl b/src/protojl/tensorboard/tensorboard.jl
index 1851dba2..99a89b60 100644
--- a/src/protojl/tensorboard/tensorboard.jl
+++ b/src/protojl/tensorboard/tensorboard.jl
@@ -34,6 +34,10 @@ module tensorboard
     include("plugins/custom_scalar/layout_pb.jl")
     include("plugins/text/plugin_data_pb.jl")
 
-    #include("plugins/hparams/hparams.jl")
-
+    # Needs separate module due to conflicting "_DataType" export
+    module hparams
+        include("plugins/hparams/api_pb.jl")
+        include("plugins/hparams/hparams_util_pb.jl")
+        include("plugins/hparams/plugin_data_pb.jl")
+    end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index f8224db0..225ddeb4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
 using TensorBoardLogger, Logging
 using TensorBoardLogger: preprocess, summary_impl
+using TensorBoardLogger: IntervalDomain, DiscreteDomain, HParam, Metric, HParamsConfig
 using Test
 using TestImages
 using ImageCore
@@ -139,7 +140,7 @@ end
     @test π != log_image(logger, "rand/LCN", rand(10, 3, 2), LCN, step = step)
 
     if VERSION >= v"1.3.0"
-        using MLDatasets: MNIST 
+        using MLDatasets: MNIST
 
         sample = MNIST.traintensor(1:3)
         @test π != log_image(logger, "mnist/HWN", sample, HWN, step = step)
@@ -268,6 +269,55 @@ end
     @test π != log_embeddings(logger, "random2", mat, step = step+1)
 end
 
+@testset "HParamConfig Logger" begin
+    logger = TBLogger(test_log_dir*"t", tb_overwrite)
+    step = 1
+
+    interval_domain = IntervalDomain(0.1, 3.0)
+    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")
+
+    discrete_domain_strs = ["a", "b", "c"]
+    discrete_domain = DiscreteDomain(discrete_domain_strs)
+    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")
+
+    hparams = [hparam1, hparam2]
+
+    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
+    metrics = [metric]
+    hparams_config = HParamsConfig(hparams, metrics, 1.2)
+    ss = TensorBoardLogger.hparams_config_summary(hparams_config)
+
+    @test isa(ss, TensorBoardLogger.Summary_Value)
+    @test ss.tag == TensorBoardLogger.EXPERIMENT_TAG
+
+    # TODO: Deserialize and test more properties
+
+    log_hparams_config(logger, hparams_config ;step=step)
+end
+
+@testset "HParams Logger" begin
+    logger = TBLogger(test_log_dir*"t", tb_overwrite)
+    step = 1
+
+    interval_domain = IntervalDomain(0.1, 3.0)
+    hparam1 = HParam("interval_hparam", interval_domain, "display_name1", "description1")
+
+    discrete_domain_strs = ["a", "b", "c"]
+    discrete_domain = DiscreteDomain(discrete_domain_strs)
+    hparam2 = HParam("discrete_domain_hparam", discrete_domain, "display_name2", "description2")
+
+    hparams_dict = Dict(hparam1 => 1.2, hparam2 => "b")
+
+    ss = TensorBoardLogger.hparams_summary(hparams_dict, "group_name", "trial_id", nothing)
+
+    @test isa(ss, TensorBoardLogger.Summary_Value)
+    @test ss.tag == TensorBoardLogger.SESSION_START_INFO_TAG
+
+    # TODO: Deserialize and test more properties
+    log_hparams(logger, hparams_dict, "group_name", "trial_id", nothing ;step=step)
+end
+
+
 @testset "Logger dispatch overrides" begin
     include("test_logger_dispatch_overrides.jl")
 end
diff --git a/test/test_logger_dispatch_overrides.jl b/test/test_logger_dispatch_overrides.jl
index 9717721d..2e67a4c2 100644
--- a/test/test_logger_dispatch_overrides.jl
+++ b/test/test_logger_dispatch_overrides.jl
@@ -1,5 +1,6 @@
 using TensorBoardLogger, Test
 using TensorBoardLogger: preprocess, content
+using TensorBoardLogger: TBHParams, HParam, IntervalDomain, Metric, HParamsConfig
 using TestImages
 using ImageCore
 @testset "TBText" begin
@@ -61,3 +62,23 @@ end
     @test first(data) == ("test/1"=>TBAudio(y, sample_rate))
     @test last(data) == ("test/2"=>TBAudio(y, sample_rate))
 end
+
+@testset "HParamsConfig" begin
+    data = Vector{Pair{String,Any}}()
+    hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1")
+    metric = Metric("tag", "group", "display_name", "description", :DATASET_VALIDATION)
+    params_config = HParamsConfig([hparam], [metric], 1.2)
+    preprocess("test", params_config, data)
+    @test first(data) == ("test"=>params_config)
+    @test isa(TensorBoardLogger.summary_impl("test", params_config), TensorBoardLogger.Summary_Value)
+end
+
+@testset "TBHParams" begin
+    data = Vector{Pair{String,Any}}()
+    hparam = HParam("interval_hparam", IntervalDomain(0.1, 3.0), "display_name1", "description1")
+    hparams_dict = Dict(hparam => 1.2)
+    tbh_params = TBHParams(hparams_dict, "group_name", "trial_id", nothing)
+    preprocess("test", tbh_params, data)
+    @test first(data) == ("test"=>tbh_params)
+    @test isa(TensorBoardLogger.summary_impl("test", tbh_params), TensorBoardLogger.Summary_Value)
+end