Move compressionCodec to a string parameter
This change allows Python (and presumably Java) users to use compressed
output by passing the compressionCodec as a string option, instead of an
argument to a function only visible in the implicit class.

This is an API breaking change, so I'm definitely open to other
approaches.
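
For context, here is a rough before/after sketch from the Scala side (a hedged illustration only: `df` stands for an existing DataFrame, the paths and the GzipCodec choice are placeholders; Python and Java go through the same string option):

```scala
import org.apache.hadoop.io.compress.GzipCodec
import com.databricks.spark.csv._  // brings the implicit saveAsCsvFile into scope

// Before: compression could only be requested through the implicit-class method,
// which Python (and Java) callers cannot reach.
df.saveAsCsvFile("newcars.csv.gz", Map("delimiter" -> ","), classOf[GzipCodec])

// After: the codec is an ordinary string option on the write path,
// so any language binding can set it.
df.write
  .format("com.databricks.spark.csv")
  .option("header", "true")
  .option("codec", "org.apache.hadoop.io.compress.GzipCodec")
  .save("newcars.csv.gz")
```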

Author: Max Sperlich <[email protected]>

Closes #174 from msperlich/master.
msperlich authored and falaki committed Nov 20, 2015
1 parent 472c20d commit 84858ce
Showing 6 changed files with 144 additions and 9 deletions.
91 changes: 89 additions & 2 deletions README.md
@@ -55,6 +55,7 @@ When reading files the API accepts several options:
* `charset`: defaults to 'UTF-8' but can be set to other valid charset names
* `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
* `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
* `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.

The package also supports saving a simple (non-nested) DataFrame. When saving, you can specify the delimiter and whether a header row should be generated for the table. See the following examples for more details.

@@ -127,6 +128,24 @@ selectedData.write
.save("newcars.csv")
```

You can save with compressed output:
```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
val df = sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true") // Use first line of all files as header
.option("inferSchema", "true") // Automatically infer data types
.load("cars.csv")

val selectedData = df.select("year", "model")
selectedData.write
.format("com.databricks.spark.csv")
.option("header", "true")
.option("codec", "org.apache.hadoop.io.compress.GzipCodec")
.save("newcars.csv.gz")
```

__Spark 1.3:__

@@ -209,7 +228,23 @@ df.select("year", "model").write()
.save("newcars.csv");
```

You can save with compressed output:
```java
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

SQLContext sqlContext = new SQLContext(sc);
DataFrame df = sqlContext.read()
.format("com.databricks.spark.csv")
.option("inferSchema", "true")
.option("header", "true")
.load("cars.csv");

df.select("year", "model").write()
.format("com.databricks.spark.csv")
.option("header", "true")
.option("codec", "org.apache.hadoop.io.compress.GzipCodec")
.save("newcars.csv");
```

__Spark 1.3:__

@@ -235,7 +270,7 @@ import org.apache.spark.sql.types.*;

SQLContext sqlContext = new SQLContext(sc);
StructType customSchema = new StructType(
new StructField("year", IntegerType, true),
new StructField("year", IntegerType, true),
new StructField("make", StringType, true),
new StructField("model", StringType, true),
new StructField("comment", StringType, true),
@@ -250,6 +285,29 @@ DataFrame df = sqlContext.load("com.databricks.spark.csv", customSchema, options
df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv");
```

You can save with compressed output:
```java
import java.util.HashMap;

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;

SQLContext sqlContext = new SQLContext(sc);

HashMap<String, String> options = new HashMap<String, String>();
options.put("header", "true");
options.put("path", "cars.csv");
options.put("inferSchema", "true");

DataFrame df = sqlContext.load("com.databricks.spark.csv", options);

HashMap<String, String> saveOptions = new HashMap<String, String>();
saveOptions.put("header", "true");
saveOptions.put("path", "newcars.csv");
saveOptions.put("codec", "org.apache.hadoop.io.compress.GzipCodec");

df.select("year", "model").save("com.databricks.spark.csv", SaveMode.Overwrite,
saveOptions);
```

### Python API

__Spark 1.4+:__
@@ -286,6 +344,14 @@ df.select('year', 'model').write \
.save('newcars.csv')
```

You can save with compressed output:
```python
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

df = sqlContext.read.format('com.databricks.spark.csv').options(header='true', inferschema='true').load('cars.csv')
df.select('year', 'model').write.format('com.databricks.spark.csv').options(codec="org.apache.hadoop.io.compress.GzipCodec").save('newcars.csv')
```

__Spark 1.3:__

@@ -315,6 +381,15 @@ df = sqlContext.load(source="com.databricks.spark.csv", header = 'true', schema
df.select('year', 'model').save('newcars.csv', 'com.databricks.spark.csv')
```

You can save with compressed output:
```python
from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

df = sqlContext.load(source="com.databricks.spark.csv", header = 'true', inferSchema = 'true', path = 'cars.csv')
df.select('year', 'model').save('newcars.csv', 'com.databricks.spark.csv', codec="org.apache.hadoop.io.compress.GzipCodec")
```

### R API
__Spark 1.4+:__

@@ -337,7 +412,7 @@ library(SparkR)
Sys.setenv('SPARKR_SUBMIT_ARGS'='"--packages" "com.databricks:spark-csv_2.10:1.3.0" "sparkr-shell"')
sqlContext <- sparkRSQL.init(sc)
customSchema <- structType(
structField("year", "integer"),
structField("year", "integer"),
structField("make", "string"),
structField("model", "string"),
structField("comment", "string"),
@@ -348,5 +423,17 @@ df <- read.df(sqlContext, "cars.csv", source = "com.databricks.spark.csv", schem
write.df(df, "newcars.csv", "com.databricks.spark.csv", "overwrite")
```

You can save with compressed output:
```R
library(SparkR)

Sys.setenv('SPARKR_SUBMIT_ARGS'='"--packages" "com.databricks:spark-csv_2.10:1.3.0" "sparkr-shell"')
sqlContext <- sparkRSQL.init(sc)

df <- read.df(sqlContext, "cars.csv", source = "com.databricks.spark.csv", schema = customSchema, inferSchema = "true")

write.df(df, "newcars.csv", "com.databricks.spark.csv", "overwrite", codec="org.apache.hadoop.io.compress.GzipCodec")
```

## Building From Source
This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line-Reference.html), which is automatically downloaded by the included shell script. To build a JAR file simply run `sbt/sbt package` from the project root. The build configuration includes support for both Scala 2.10 and 2.11.
12 changes: 10 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvParser.scala
@@ -39,6 +39,7 @@ class CsvParser extends Serializable {
private var parserLib: String = ParserLibs.DEFAULT
private var charset: String = TextFile.DEFAULT_CHARSET.name()
private var inferSchema: Boolean = false
private var codec: String = null

def withUseHeader(flag: Boolean): CsvParser = {
this.useHeader = flag
@@ -105,6 +106,11 @@ class CsvParser extends Serializable {
this
}

def withCompression(codec: String): CsvParser = {
this.codec = codec
this
}

/** Returns a Schema RDD for the given CSV path. */
@throws[RuntimeException]
def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -122,7 +128,8 @@
ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
inferSchema,
codec)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}

@@ -141,7 +148,8 @@
ignoreTrailingWhiteSpace,
treatEmptyValuesAsNulls,
schema,
inferSchema)(sqlContext)
inferSchema,
codec)(sqlContext)
sqlContext.baseRelationToDataFrame(relation)
}
}
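
For the programmatic Scala entry point, a minimal usage sketch of the new builder method (an illustration under assumptions: `sqlContext` and `cars.csv` are placeholders, and the codec set here only takes effect when the resulting relation writes data, e.g. on INSERT OVERWRITE — it does not change how the file is read):

```scala
import com.databricks.spark.csv.CsvParser

val cars = new CsvParser()
  .withUseHeader(true)
  .withCompression("org.apache.hadoop.io.compress.GzipCodec")
  .csvFile(sqlContext, "cars.csv")
```
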
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/CsvRelation.scala
@@ -45,7 +45,8 @@ case class CsvRelation protected[spark] (
ignoreTrailingWhiteSpace: Boolean,
treatEmptyValuesAsNulls: Boolean,
userSchema: StructType = null,
inferCsvSchema: Boolean)(@transient val sqlContext: SQLContext)
inferCsvSchema: Boolean,
codec: String = null)(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

/**
@@ -314,7 +315,10 @@
+ s" to INSERT OVERWRITE a CSV table:\n${e.toString}")
}
// Write the data. We assume that schema isn't changed, and we won't update it.
data.saveAsCsvFile(filesystemPath.toString, Map("delimiter" -> delimiter.toString))

val codecClass = compresionCodecClass(codec)
data.saveAsCsvFile(filesystemPath.toString, Map("delimiter" -> delimiter.toString),
codecClass)
} else {
sys.error("CSV tables only support INSERT OVERWRITE for now.")
}
8 changes: 6 additions & 2 deletions src/main/scala/com/databricks/spark/csv/DefaultSource.scala
@@ -137,6 +137,8 @@ class DefaultSource
throw new Exception("Infer schema flag can be true or false")
}

val codec = parameters.getOrElse("codec", null)

CsvRelation(
() => TextFile.withCharset(sqlContext.sparkContext, path, charset),
Some(path),
@@ -151,7 +153,8 @@
ignoreTrailingWhiteSpaceFlag,
treatEmptyValuesAsNullsFlag,
schema,
inferSchemaFlag)(sqlContext)
inferSchemaFlag,
codec)(sqlContext)
}

override def createRelation(
@@ -178,7 +181,8 @@
}
if (doSave) {
// Only save data when the save mode is not ignore.
data.saveAsCsvFile(path, parameters)
val codecClass = compresionCodecClass(parameters.getOrElse("codec", null))
data.saveAsCsvFile(path, parameters, codecClass)
}

createRelation(sqlContext, parameters, data.schema)
12 changes: 12 additions & 0 deletions src/main/scala/com/databricks/spark/csv/package.scala
@@ -26,6 +26,16 @@ package object csv {
val defaultCsvFormat =
CSVFormat.DEFAULT.withRecordSeparator(System.getProperty("line.separator", "\n"))

private[csv] def compresionCodecClass(className: String): Class[_ <: CompressionCodec] = {
className match {
case null => null
case codec =>
// scalastyle:off classforname
Class.forName(codec).asInstanceOf[Class[CompressionCodec]]
// scalastyle:on classforname
}
}

/**
* Adds a method, `csvFile`, to SQLContext that allows reading CSV data.
*/
@@ -90,6 +100,8 @@ package object csv {

/**
* Saves a DataFrame as CSV files. By default uses ',' as the delimiter and includes a header line.
* If compressionCodec is not null, the resulting output will be compressed.
* Note that a "codec" entry in the parameters map will be ignored.
*/
def saveAsCsvFile(path: String, parameters: Map[String, String] = Map(),
compressionCodec: Class[_ <: CompressionCodec] = null): Unit = {
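
To illustrate the note above: when calling the implicit-class method directly, compression is controlled only by the class argument; a "codec" entry in the parameters map is not consulted there (the string form is handled by DefaultSource, as shown earlier). A minimal sketch, assuming an existing DataFrame `df` and placeholder paths:

```scala
import org.apache.hadoop.io.compress.GzipCodec
import com.databricks.spark.csv._

// The class argument requests gzip output; the "codec" map entry below is ignored
// by saveAsCsvFile and is included only to make the point.
df.saveAsCsvFile(
  "newcars.csv.gz",
  Map("delimiter" -> ",", "codec" -> "org.apache.hadoop.io.compress.GzipCodec"),
  classOf[GzipCodec])
```
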
22 changes: 21 additions & 1 deletion src/test/scala/com/databricks/spark/csv/CsvSuite.scala
@@ -21,7 +21,7 @@ import java.sql.Timestamp

import com.databricks.spark.csv.util.ParseModes
import org.apache.hadoop.io.compress.GzipCodec
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.{SQLContext, Row, SaveMode}
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -431,6 +431,26 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
}

test("Scala API save with gzip compression codec") {
// Create temp directory
TestUtils.deleteRecursively(new File(tempEmptyDir))
new File(tempEmptyDir).mkdirs()
val copyFilePath = tempEmptyDir + "cars-copy.csv"

val cars = sqlContext.csvFile(carsFile, parserLib = parserLib)
cars.save("com.databricks.spark.csv", SaveMode.Overwrite,
Map("path" -> copyFilePath, "header" -> "true", "codec" -> classOf[GzipCodec].getName))
val carsCopyPartFile = new File(copyFilePath, "part-00000.gz")
// Check that the part file has a .gz extension
assert(carsCopyPartFile.exists())

val carsCopy = sqlContext.csvFile(copyFilePath + "/")

assert(carsCopy.count == cars.count)
assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet)
}


test("DSL save with quoting") {
// Create temp directory
TestUtils.deleteRecursively(new File(tempEmptyDir))
