Skip to content

Commit

Permalink
Additional Fixes, fix exr output
Browse files Browse the repository at this point in the history
  • Loading branch information
JerryLingjieMei authored and pvl-bot committed Nov 11, 2024
1 parent 419f2f4 commit 6c4d2f3
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 21 deletions.
1 change: 1 addition & 0 deletions infinigen/core/nodes/node_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class Nodes:
WorldOutput = "ShaderNodeOutputWorld"
Composite = "CompositorNodeComposite"
Viewer = "CompositorNodeViewer"
CompositorMixRGB = "CompositorNodeMixRGB"

# Point
DistributePointsOnFaces = "GeometryNodeDistributePointsOnFaces"
Expand Down
21 changes: 18 additions & 3 deletions infinigen/core/rendering/post_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import logging
import os

import OpenEXR

# ruff: noqa: E402
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" # This must be done BEFORE import cv2.

Expand All @@ -30,16 +32,29 @@ def load_exr(path):
load_flow = load_exr


def load_depth(p):
return load_exr(p)[..., 0]
def load_single_channel(p):
file = OpenEXR.InputFile(str(p))
channel, channel_type = next(iter(file.header()["channels"].items()))
match str(channel_type.type):
case "FLOAT":
np_type = np.float32
case _:
np_type = np.uint8
data = np.frombuffer(file.channel(channel, channel_type.type), np_type)
dw = file.header()["dataWindow"]
sz = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
return data.reshape(sz)


load_depth = load_single_channel


def load_normals(p):
return load_exr(p)[..., [2, 0, 1]] * np.array([-1.0, 1.0, 1.0])


def load_seg_mask(p):
return load_exr(p)[..., 2].astype(np.int64)
return load_single_channel(p).astype(np.int64)


def load_uniq_inst(p):
Expand Down
25 changes: 17 additions & 8 deletions infinigen/core/rendering/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,23 @@ def configure_compositor_output(

slot_input = file_output_node.file_slots.new(socket_name)
render_socket = render_layers.outputs[socket_name]
if viewlayer_pass == "vector":
separate_color = nw.new_node(Nodes.CompSeparateColor, [render_socket])
comnbine_color = nw.new_node(
Nodes.CompCombineColor, [0, (separate_color, 3), (separate_color, 2), 0]
)
nw.links.new(comnbine_color.outputs[0], slot_input)
else:
nw.links.new(render_socket, slot_input)
match viewlayer_pass:
case "vector":
separate_color = nw.new_node(Nodes.CompSeparateColor, [render_socket])
comnbine_color = nw.new_node(
Nodes.CompCombineColor,
[0, (separate_color, 3), (separate_color, 2), 0],
)
nw.links.new(comnbine_color.outputs[0], slot_input)
case "normal":
color = nw.new_node(
Nodes.CompositorMixRGB,
[None, render_socket, (0, 0, 0, 0)],
attrs={"blend_type": "ADD"},
).outputs[0]
nw.links.new(color, slot_input)
case _:
nw.links.new(render_socket, slot_input)
file_slot_list.append(file_output_node.file_slots[slot_input.name])

slot_input = default_file_output_node.file_slots["Image"]
Expand Down
7 changes: 5 additions & 2 deletions infinigen/core/util/blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from infinigen.core.nodes.node_info import DATATYPE_DIMS, DATATYPE_FIELDS

from ..nodes.node_wrangler import ng_inputs
from . import math as mutil
from .logging import Suppress

Expand Down Expand Up @@ -556,7 +555,11 @@ def get_camera_res():
def set_geomod_inputs(mod, inputs: dict):
assert mod.type == "NODES"
for k, v in inputs.items():
inputs = ng_inputs(mod.node_group)
inputs = {
s.name: s
for s in mod.node_group.interface.items_tree
if s.in_out == "INPUT"
}
if k not in inputs:
raise KeyError(f"Couldnt find {k=} in {mod.node_group.inputs.keys()=}")
soc = inputs[k]
Expand Down
17 changes: 9 additions & 8 deletions infinigen/terrain/surface_kernel/kernelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from infinigen.terrain.utils import (
NODE_ATTRS_AVAILABLE,
NODE_FUNCTIONS,
SOCKETTYPE_KERNEL,
KernelDataType,
Nodes,
SocketType,
Expand All @@ -29,6 +28,7 @@
value_string,
var_list,
)
from infinigen.terrain.utils.kernelizer_util import SOCKETTYPE_KERNEL, SOCKETTYPES

functional_nodes = [
Nodes.SetPosition,
Expand All @@ -50,9 +50,10 @@ class Kernelizer:
def get_inputs(self, node_tree):
inputs = OrderedDict()
for node_input in ng_inputs(node_tree).values():
if node_input.type != SocketType.Geometry:
assert node_input.type != SocketType.Image
inputs[node_input.identifier] = SOCKETTYPE_KERNEL[node_input.type]
socket_type = SOCKETTYPES[node_input.socket_type]
if socket_type != SocketType.Geometry:
assert socket_type != SocketType.Image
inputs[node_input.identifier] = SOCKETTYPE_KERNEL[socket_type]
return inputs

def get_output(self, node_tree):
Expand All @@ -61,8 +62,9 @@ def get_output(self, node_tree):
if node.bl_idname == Nodes.SetPosition:
outputs[Vars.Offset] = KernelDataType.float3
for node_output in ng_outputs(node_tree).values():
if node_output.type != SocketType.Geometry:
outputs[node_output.identifier] = SOCKETTYPE_KERNEL[node_output.type]
socket_type = SOCKETTYPES[node_output.socket_type]
if socket_type != SocketType.Geometry:
outputs[node_output.identifier] = SOCKETTYPE_KERNEL[socket_type]
return outputs

def regularize(self, node_tree):
Expand Down Expand Up @@ -418,8 +420,7 @@ def __call__(self, modifier):
node_tree, collective_style=True
)
for nodeoutput in ng_outputs(node_tree).values():
id = nodeoutput.identifier
if id != "Output_1": # not Geometry
if nodeoutput.socket_type != "NodeSocketGeometry":
code = re.sub(rf"\b{id}\b", modifier[f"{id}_attribute_name"], code)
outputs[modifier[f"{id}_attribute_name"]] = outputs.pop(id)
return code, imp_inputs, outputs
8 changes: 8 additions & 0 deletions infinigen/terrain/utils/kernelizer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,14 @@ class KernelDataType:
# BOOL todo when necessary
}

SOCKETTYPES = {
"NodeSocketFloat": SocketType.Value,
"NodeSocketVector": SocketType.Vector,
"NodeSocketInt": SocketType.Int,
"NodeSocketColor": SocketType.RGBA,
"NodeSocketImage": SocketType.Image,
"NodeSocketGeometry": SocketType.Geometry,
}

NODE_FUNCTIONS = {
Nodes.WaveTexture: "node_shader_tex_wave",
Expand Down

0 comments on commit 6c4d2f3

Please sign in to comment.