Skip to content

Commit 4531eb6

Browse files
yupbankthunterdb
authored andcommitted
Manually append meta data (#141)
* add append shape * add python bindings * fix the column * aaaa * address beatiful comments
1 parent eca77ca commit 4531eb6

File tree

4 files changed

+90
-4
lines changed

4 files changed

+90
-4
lines changed

src/main/python/tensorframes/core.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pyspark.sql.types import DoubleType, IntegerType, LongType, FloatType, ArrayType
1010

1111
__all__ = ['reduce_rows', 'map_rows', 'reduce_blocks', 'map_blocks',
12-
'analyze', 'print_schema', 'aggregate', 'block', 'row']
12+
'analyze', 'print_schema', 'aggregate', 'block', 'row',
13+
'append_shape']
1314

1415
_sc = None
1516
_sql = None
@@ -377,6 +378,26 @@ def analyze(dframe):
377378
"""
378379
return DataFrame(_java_api().analyze(dframe._jdf), _sql)
379380

381+
def append_shape(dframe, col, shape):
382+
"""Append extra metadata for a dataframe that
383+
describes the numerical shape of the content.
384+
385+
This method is useful when a dataframe contains non-scalar tensors, for which the shape must be checked beforehand.
386+
The user is responsible for providing the right shape, any mismatch will trigger eventually an exception in Spark
387+
388+
Note: nullable fields are not accepted.
389+
390+
The function [print_schema] lets users introspect the information added to the DataFrame.
391+
392+
:param dframe: a Spark DataFrame
393+
:param col: a Column expression
394+
:param shape: a shape corresponding to the tensor,
395+
detailed explanation https://www.tensorflow.org/programmers_guide/tensors#shape
396+
:return: a Spark DataFrame with metadata information embedded.
397+
"""
398+
shape = [i or -1 for i in shape]
399+
return DataFrame(_java_api().appendShape(dframe._jdf, col._jc, shape), _sql)
400+
380401
def aggregate(fetches, grouped_data, initial_variables=_initial_variables_default):
381402
"""
382403
Performs an algebraic aggregation on the grouped data.

src/main/python/tensorframes/core_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pyspark import SparkContext
55
from pyspark.sql import DataFrame, SQLContext
66
from pyspark.sql import Row
7+
from pyspark.sql.functions import col
78
import tensorflow as tf
89
import pandas as pd
910

@@ -198,6 +199,20 @@ def test_reduce_rows_1(self):
198199
res = tfs.reduce_rows(x, df)
199200
assert res == sum([r.x for r in data])
200201

202+
def test_append_shape(self):
203+
data = [Row(x=float(x)) for x in range(5)]
204+
df = self.sql.createDataFrame(data)
205+
ddf = tfs.append_shape(df, col('x'), [-1])
206+
with tf.Graph().as_default():
207+
# The placeholder that corresponds to column 'x'
208+
x_1 = tf.placeholder(tf.double, shape=[], name="x_1")
209+
x_2 = tf.placeholder(tf.double, shape=[], name="x_2")
210+
# The output that adds 3 to x
211+
x = tf.add(x_1, x_2, name='x')
212+
# The resulting number
213+
res = tfs.reduce_rows(x, ddf)
214+
assert res == sum([r.x for r in data])
215+
201216
# This test fails
202217
def test_reduce_blocks_1(self):
203218
data = [Row(x=float(x)) for x in range(5)]

src/main/scala/org/tensorframes/ExperimentalOperations.scala

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package org.tensorframes
22

3+
import java.util
4+
import org.apache.spark.sql.Column
35
import org.apache.spark.sql.DataFrame
46
import org.apache.spark.sql.functions.col
5-
import org.apache.spark.sql.types.{ArrayType, DataType, NumericType}
6-
7+
import org.apache.spark.sql.types.{ArrayType, DataType, MetadataBuilder, NumericType}
78
import org.tensorframes.impl.{ScalarType, SupportedOperations}
9+
import scala.collection.JavaConverters._
10+
11+
812

913
/**
1014
* Some useful methods for operating on dataframes that are not part of the official API (and thus may change anytime).
@@ -45,6 +49,23 @@ trait ExperimentalOperations {
4549
}
4650
df.select(cols: _*)
4751
}
52+
53+
def appendShape(df: DataFrame, col: Column, shape: Array[Int]): DataFrame = {
54+
55+
val meta = new MetadataBuilder
56+
val colDtypes = df.select(col).schema.fields.head.dataType
57+
val basicDatatype =
58+
ExtraOperations.extractBasicType(colDtypes).getOrElse(throw new Exception(s"'$colDtypes' was not supported"))
59+
60+
meta.putString(MetadataConstants.tensorStructType,
61+
SupportedOperations.opsFor(basicDatatype).sqlType.toString
62+
)
63+
meta.putLongArray(MetadataConstants.shapeKey, shape.map(_.asInstanceOf[Long]))
64+
df.withColumn(col.toString(), col.as("", meta.build()))
65+
}
66+
67+
def appendShape(df: DataFrame, col:Column, shape: util.ArrayList[Int]): DataFrame =
68+
appendShape(df, col, shape.asScala.toArray[Int])
4869
}
4970

5071
private[tensorframes] object ExtraOperations extends ExperimentalOperations with Logging {
@@ -110,7 +131,7 @@ private[tensorframes] object ExtraOperations extends ExperimentalOperations with
110131
DataFrameInfo(allInfo)
111132
}
112133

113-
private def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
134+
def extractBasicType(dt: DataType): Option[ScalarType] = dt match {
114135
case x: NumericType => Some(SupportedOperations.opsFor(x).scalarType)
115136
case x: ArrayType => extractBasicType(x.elementType)
116137
case _ => None

src/test/scala/org/tensorframes/BasicOperationsSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.tensorframes
33
import org.scalatest.FunSuite
44

55
import org.apache.spark.sql.Row
6+
import org.apache.spark.sql.functions.col
67

78
import org.tensorframes.dsl.Implicits._
89
import org.tensorframes.dsl._
@@ -42,6 +43,15 @@ class BasicOperationsSuite
4243
compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
4344
}
4445

46+
testGraph("Identity - 1 dim, Manually") {
47+
val df = make1(Seq(Seq(1.0), Seq(2.0)), "in")
48+
val adf = ops.appendShape(df, col("in"), Array(-1, 1))
49+
val p1 = placeholder[Double](Unknown, 1) named "in"
50+
val out = identity(p1) named "out"
51+
val df2 = adf.mapBlocks(out).select("in", "out")
52+
compareRows(df2.collect(), Array(Row(Seq(1.0), Seq(1.0)), Row(Seq(2.0), Seq(2.0))))
53+
}
54+
4555
testGraph("Simple add - 1 dim") {
4656
val a = placeholder[Double](Unknown, 1) named "a"
4757
val b = placeholder[Double](Unknown, 1) named "b"
@@ -57,6 +67,25 @@ class BasicOperationsSuite
5767
Row(Seq(2.0), Seq(2.2), Seq(4.2))))
5868
}
5969

70+
testGraph("Simple add - 1 dim, Manually") {
71+
val a = placeholder[Double](Unknown, 1) named "a"
72+
val b = placeholder[Double](Unknown, 1) named "b"
73+
val out = a + b named "out"
74+
75+
val df = sql.createDataFrame(Seq(
76+
Seq(1.0)->Seq(1.1),
77+
Seq(2.0)->Seq(2.2))).toDF("a", "b")
78+
val adf = {
79+
ops.appendShape(
80+
ops.appendShape(df, col("a"), Array(-1, 1)),
81+
col("b"), Array(-1, 1))
82+
}
83+
val df2 = adf.mapBlocks(out).select("a", "b","out")
84+
compareRows(df2.collect(), Array(
85+
Row(Seq(1.0), Seq(1.1), Seq(2.1)),
86+
Row(Seq(2.0), Seq(2.2), Seq(4.2))))
87+
}
88+
6089
testGraph("Reduce - sum double") {
6190
val df = make1(Seq(1.0, 2.0), "x")
6291
val x1 = placeholder[Double]() named "x_1"

0 commit comments

Comments
 (0)