Skip to content

Commit a4d253e

Browse files
committedMar 10, 2025·
Supporting flags on model that turns off optimizations.
1 parent b657c3d commit a4d253e

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed
 

‎WORKSPACE

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
55

66
git_repository(
77
name = "ccv",
8-
commit = "b2bcf7f26cb8967ffcad71218e41431bf40b78fe",
8+
commit = "84769a56bacd851ad26d1d2843fa5075aff661e5",
99
remote = "https://github.com/liuliu/ccv.git",
10-
shallow_since = "1740446209 -0500",
10+
shallow_since = "1741587468 -0400",
1111
)
1212

1313
load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")

‎deps.bzl

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def s4nnc_deps():
1717
git_repository,
1818
name = "ccv",
1919
remote = "https://github.com/liuliu/ccv.git",
20-
commit = "b2bcf7f26cb8967ffcad71218e41431bf40b78fe",
21-
shallow_since = "1740446209 -0500",
20+
commit = "84769a56bacd851ad26d1d2843fa5075aff661e5",
21+
shallow_since = "1741587468 -0400",
2222
)
2323

2424
_maybe(

‎nnc/Model.swift

+25
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,31 @@ public class Model: AnyModel {
244244
}
245245
}
246246

247+
public struct EnableBits: OptionSet, CaseIterable {
248+
public let rawValue: Int32
249+
public init(rawValue: Int32) {
250+
self.rawValue = rawValue
251+
}
252+
/**
253+
* Disable Optimizations
254+
*/
255+
public static let disableOpt = EnableBits(
256+
rawValue: Int32(CCV_NNC_GRAPH_EXEC_DISABLE_OPT))
257+
public static let allCases: [EnableBits] = [.disableOpt]
258+
}
259+
/**
260+
* The flags for the underlying execution node.
261+
*/
262+
public var flags: EnableBits {
263+
get {
264+
let value = ccv_cnnp_model_flags(cModel)
265+
return EnableBits(rawValue: value)
266+
}
267+
set {
268+
let _ = ccv_cnnp_model_set_flags(cModel, newValue.rawValue)
269+
}
270+
}
271+
247272
public enum ParametersType {
248273
case weight
249274
case bias

‎nnc/ModelAddons.swift

+12-5
Original file line numberDiff line numberDiff line change
@@ -1148,14 +1148,20 @@ public final class Concat: Model {
11481148
}
11491149

11501150
extension Functional {
1151-
public static func concat(axis: Int, _ inputs: ModelIOConvertible...) -> Model.IO {
1152-
return Concat(axis: axis).apply(inputs)
1151+
public static func concat(
1152+
axis: Int, _ inputs: ModelIOConvertible..., flags: Model.EnableBits = []
1153+
) -> Model.IO {
1154+
let concat = Concat(axis: axis)
1155+
concat.flags = flags
1156+
return concat.apply(inputs)
11531157
}
11541158

11551159
public static func concat<T: DynamicGraph.TensorGroup>(
1156-
axis: Int, _ inputs: T..., streamContext: StreamContext? = nil
1160+
axis: Int, _ inputs: T..., flags: Model.EnableBits = [], streamContext: StreamContext? = nil
11571161
) -> T {
1158-
let outputs = Concat(axis: axis)(
1162+
let concat = Concat(axis: axis)
1163+
concat.flags = flags
1164+
let outputs = concat(
11591165
inputs: inputs[0], Array(inputs.suffix(from: 1)), streamContext: streamContext)
11601166
return T(outputs[0])
11611167
}
@@ -1535,7 +1541,8 @@ public final class Debug: Model {
15351541
extension ModelIOConvertible {
15361542
/// Move the value to another Model.IO. This is a special operation that can perform optimizations
15371543
/// violates SSA. Use it with extreme care.
1538-
public func debug(name: String = "", _ callback: @escaping ([AnyTensor?], StreamContext?) -> Void) -> Model.IO
1544+
public func debug(name: String = "", _ callback: @escaping ([AnyTensor?], StreamContext?) -> Void)
1545+
-> Model.IO
15391546
{
15401547
return Debug(name: name, callback)(self)
15411548
}

0 commit comments

Comments
 (0)
Please sign in to comment.