Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
aebadc0
Creating Runtime Module for IR to IO , adding libsvm support(wip)
pantShrey Jul 15, 2025
475ddac
corrected the translator test cases regarding data type 6 and 7
pantShrey Jul 15, 2025
8b67cf0
removed memcpy from inputs and outputs and used flatTraverse in trans…
pantShrey Jul 18, 2025
776f8ee
replaced for with while and hard coded separate methods
pantShrey Jul 18, 2025
7dfec45
logic placed in a single IO in handleSvmClassifier
pantShrey Jul 18, 2025
8dfdf40
corrected libsvm implementation returned a resource inplace of immedi…
pantShrey Jul 25, 2025
73ed66d
fixing the libsvm warning
pantShrey Jul 25, 2025
c5c6596
adding CE 3.7 RC-1 and operationsin modelir required for mnist
pantShrey Aug 7, 2025
42ab2eb
removin comments to self
pantShrey Aug 7, 2025
78e699a
Added the new operation in Translator.scala and added a few test cases
pantShrey Aug 8, 2025
cc3b13c
added mnist parsing getting wrong results
pantShrey Aug 20, 2025
8d723a7
changes to not break CI
pantShrey Aug 20, 2025
9c8fa20
changes to not break CI
pantShrey Aug 20, 2025
3440e8a
sbt prePR
pantShrey Aug 20, 2025
a04984c
Added MLPack support for conv , pool and softmax
pantShrey Aug 28, 2025
674944e
corrected build.sbt
pantShrey Aug 28, 2025
5536fe6
corrected build.sbt
pantShrey Aug 28, 2025
a99a44d
Changed mlpack usage to reduce allocation and initialisations
pantShrey Sep 2, 2025
77df50e
Added Snapshot Resolver for sn 0.5.9-SNAPSHOT
pantShrey Sep 10, 2025
ddc63a7
Added cereal as mlpack depends on cereal
pantShrey Sep 10, 2025
6887541
broke down conv and max pool in 2 phases
pantShrey Sep 25, 2025
e6dd86c
removing main.scala experiment
pantShrey Sep 25, 2025
3576791
add @define and #ifdef to enable C++ wrapper compilation and flag pro…
pantShrey Oct 4, 2025
ff306b9
feat: add utility functions and update ce version
pantShrey May 17, 2026
9266fd6
build: bump sbt-typelevel to 0.8.5, scala to 2.13.18, brew plugin to …
pantShrey May 20, 2026
5077516
fix: update mlpack Conv and MaxPool layer type aliases for 4.7.0 API
pantShrey May 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,28 @@ concurrency:

jobs:
build:
name: Build and Test
name: Test
strategy:
matrix:
os: [ubuntu-latest]
os: [ubuntu-22.04]
scala: [3, 2.13]
java: [temurin@21]
project: [rootNative]
runs-on: ${{ matrix.os }}
timeout-minutes: 60
steps:
- name: Install sbt
uses: sbt/setup-sbt@v1

- name: Checkout current branch (full)
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0

- name: Setup sbt
uses: sbt/setup-sbt@v1

- name: Setup Java (temurin@21)
id: setup-java-temurin-21
if: matrix.java == 'temurin@21'
uses: actions/setup-java@v4
uses: actions/setup-java@v5
with:
distribution: temurin
java-version: 21
Expand All @@ -60,13 +60,13 @@ jobs:

- name: Install brew formulae (ubuntu)
if: startsWith(matrix.os, 'ubuntu')
run: /home/linuxbrew/.linuxbrew/bin/brew install mlpack openblas
run: /home/linuxbrew/.linuxbrew/bin/brew install cereal libsvm mlpack openblas

- name: Check that workflows are up to date
run: sbt githubWorkflowCheck

- name: Check headers and formatting
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-22.04'
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' headerCheckAll scalafmtCheckAll 'project /' scalafmtSbtCheck

- name: nativeLink
Expand All @@ -77,11 +77,11 @@ jobs:
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test

- name: Check binary compatibility
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-22.04'
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues

- name: Generate API documentation
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@21' && matrix.os == 'ubuntu-22.04'
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc

- name: Make target directories
Expand All @@ -94,7 +94,7 @@ jobs:

- name: Upload target directories
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main')
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v5
with:
name: target-${{ matrix.os }}-${{ matrix.java }}-${{ matrix.scala }}-${{ matrix.project }}
path: targets.tar
Expand All @@ -105,22 +105,22 @@ jobs:
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main')
strategy:
matrix:
os: [ubuntu-latest]
os: [ubuntu-22.04]
java: [temurin@21]
runs-on: ${{ matrix.os }}
steps:
- name: Install sbt
uses: sbt/setup-sbt@v1

- name: Checkout current branch (full)
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0

- name: Setup sbt
uses: sbt/setup-sbt@v1

- name: Setup Java (temurin@21)
id: setup-java-temurin-21
if: matrix.java == 'temurin@21'
uses: actions/setup-java@v4
uses: actions/setup-java@v5
with:
distribution: temurin
java-version: 21
Expand All @@ -131,7 +131,7 @@ jobs:
run: sbt +update

- name: Download target directories (3, rootNative)
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: target-${{ matrix.os }}-${{ matrix.java }}-3-rootNative

Expand All @@ -141,7 +141,7 @@ jobs:
rm targets.tar

- name: Download target directories (2.13, rootNative)
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: target-${{ matrix.os }}-${{ matrix.java }}-2.13-rootNative

Expand Down Expand Up @@ -179,22 +179,22 @@ jobs:
if: github.event.repository.fork == false && github.event_name != 'pull_request'
strategy:
matrix:
os: [ubuntu-latest]
os: [ubuntu-22.04]
java: [temurin@21]
runs-on: ${{ matrix.os }}
steps:
- name: Install sbt
uses: sbt/setup-sbt@v1

- name: Checkout current branch (full)
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0

- name: Setup sbt
uses: sbt/setup-sbt@v1

- name: Setup Java (temurin@21)
id: setup-java-temurin-21
if: matrix.java == 'temurin@21'
uses: actions/setup-java@v4
uses: actions/setup-java@v5
with:
distribution: temurin
java-version: 21
Expand Down
22 changes: 16 additions & 6 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@ ThisBuild / developers ++= List(
tlGitHubDev("valencik", "Andrew Valencik"),
)
ThisBuild / startYear := Some(2023)
ThisBuild / tlSonatypeUseLegacyHost := false

ThisBuild / crossScalaVersions := Seq("3.3.4", "2.13.16")

ThisBuild / crossScalaVersions := Seq("3.3.4", "2.13.18")
ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("21"))
ThisBuild / githubWorkflowBuildPreamble +=
WorkflowStep.Run(List("/home/linuxbrew/.linuxbrew/bin/brew update"), name = Some("brew update"))
ThisBuild / githubWorkflowBuildPreamble ++= nativeBrewInstallWorkflowSteps.value

ThisBuild / Test / testOptions += Tests.Argument("+l") // for munit logging

val CatsEffectVersion = "3.7-8f2b497"
val CatsEffectVersion = "3.7.0"
val CatsVersion = "2.12.0"
val MunitVersion = "1.0.4"
lazy val root = tlCrossRootProject.aggregate(ir, onnx, runtime)
Expand All @@ -40,6 +38,7 @@ lazy val onnx = project
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapb.compiler.Version.scalapbVersion % "protobuf",
"org.typelevel" %%% "cats-core" % CatsVersion,
"org.scalameta" %%% "munit" % MunitVersion % Test,
"org.typelevel" %%% "cats-effect" % CatsEffectVersion,
),
nativeConfig ~= { _.withEmbedResources(true) },
Compile / PB.generate := (Compile / PB.generate).dependsOn(Compile / downloadOnnxProto).value,
Expand Down Expand Up @@ -69,9 +68,20 @@ lazy val runtime = project
.settings(
name := "vilcacora-runtime",
libraryDependencies ++= Seq(
"org.typelevel" %%% "cats-effect-kernel" % CatsEffectVersion,
"org.typelevel" %%% "cats-effect" % CatsEffectVersion,
"org.typelevel" %%% "cats-core" % CatsVersion,
"org.scalameta" %%% "munit" % MunitVersion % Test,
"org.typelevel" %%% "keypool" % "0.5.0-RC1",
),
nativeBrewFormulas ++= Set("openblas", "mlpack"),
nativeBrewFormulas ++= Set("cereal", "openblas", "mlpack", "libsvm"),
nativeConfig ~= { c =>
c.withCompileOptions(
c.compileOptions ++ Seq(
"-fexceptions",
"-frtti",
"-Wno-inconsistent-missing-override",
"-DARMA_DONT_USE_WRAPPER",
),
)
},
)
107 changes: 107 additions & 0 deletions ir/src/main/scala/com/armanbilge/vilcacora/ir/ModelIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,24 @@ object PostTransform {
}
}

/** Representation of auto_pad attribute for Conv and Pool operations.
*/
sealed abstract class AutoPad
object AutoPad {
case object NotSet extends AutoPad
case object SameUpper extends AutoPad
case object SameLower extends AutoPad
case object Valid extends AutoPad

def fromString(s: String): Either[String, AutoPad] = s.toUpperCase match {
case "NOTSET" => Right(NotSet)
case "SAME_UPPER" => Right(SameUpper)
case "SAME_LOWER" => Right(SameLower)
case "VALID" => Right(Valid)
case other => Left(s"Unsupported AutoPad: $other")
}
}

/** An ADT representing a single operation in the computation graph.
*/
sealed abstract class Operation {
Expand Down Expand Up @@ -155,6 +173,95 @@ object Operation {
override def inputs: List[String] = List(input)
override def outputs: List[String] = List(output)
}

/** Represents a Constant operation that produces a constant tensor.
*/
final case class Constant(
output: String,
// --- Attributes ---
value: Array[Byte],
dataType: DataType,
shape: List[Int],
) extends Operation {
override def inputs: List[String] = List.empty
override def outputs: List[String] = List(output)
}

/** Represents an element-wise division operation.
*/
final case class Div(inputA: String, inputB: String, output: String) extends Operation {
override def inputs: List[String] = List(inputA, inputB)
override def outputs: List[String] = List(output)
}

/** Represents a Convolution operation.
*/
final case class Conv(
input: String,
weight: String,
bias: Option[String],
output: String,
// --- Attributes ---
autoPad: AutoPad,
dilations: List[Int],
group: Int,
kernelShape: List[Int],
pads: List[Int],
strides: List[Int],
) extends Operation {
override def inputs: List[String] = List(input, weight) ++ bias.toList
override def outputs: List[String] = List(output)
}

/** Represents a ReLU (Rectified Linear Unit) activation operation.
*/
final case class Relu(input: String, output: String) extends Operation {
override def inputs: List[String] = List(input)
override def outputs: List[String] = List(output)
}

/** Represents a MaxPool operation.
*/
final case class MaxPool(
input: String,
output: String,
// --- Attributes ---
autoPad: AutoPad,
ceilMode: Boolean,
dilations: List[Int],
kernelShape: List[Int],
pads: List[Int],
storageOrder: Int,
strides: List[Int],
) extends Operation {
override def inputs: List[String] = List(input)
override def outputs: List[String] = List(output)
}

/** Represents a Reshape operation. Takes a shape tensor as input following ONNX specification.
*/
final case class Reshape(
input: String,
shape: String, // Shape tensor input name
output: String,
// --- Attributes ---
allowzero: Boolean = false,
) extends Operation {
override def inputs: List[String] = List(input, shape)
override def outputs: List[String] = List(output)
}

/** Represents a Softmax activation operation.
*/
final case class Softmax(
input: String,
output: String,
// --- Attributes ---
axis: Int = -1, // Default axis is -1 (last dimension)
) extends Operation {
override def inputs: List[String] = List(input)
override def outputs: List[String] = List(output)
}
// Add more operations here...
}

Expand Down
37 changes: 37 additions & 0 deletions onnx/src/main/scala/vilcacora/onnx/ModelLoader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2023 Arman Bilge
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package vilcacora.onnx

import cats.effect.IO
import java.io.InputStream
import vilcacora.onnx.proto.ModelProto

object ModelLoader {

/** Loads an ONNX model from classpath as a ModelProto. Use a forward slash for resource path,
* e.g., "/mnist.onnx".
*/
def loadModelFromPath(modelPath: String): IO[ModelProto] =
IO.blocking {
val stream: InputStream = getClass.getResourceAsStream(modelPath)
if (stream == null)
throw new IllegalArgumentException(s"Resource not found: $modelPath")
try ModelProto.parseFrom(stream)

finally stream.close()
}
}
Loading