chore: Adding Spark34 support #2052

Merged
Changes from all commits (43 commits)
9cb5879
chore: bump to spark 3.3.1
JessicaXYWang Oct 4, 2022
2708928
Adding Spark3.4 Support
KeerthiYandaOS Aug 14, 2023
515ecf8
Removing comment
KeerthiYandaOS Aug 14, 2023
bc2506c
Updating evn file
KeerthiYandaOS Aug 14, 2023
6f6f7b1
Fixing build errors
KeerthiYandaOS Aug 21, 2023
31cf6a5
Fixing build error
KeerthiYandaOS Aug 21, 2023
ef8809f
fixing style error
KeerthiYandaOS Aug 21, 2023
4c00917
Fixing scala errors
KeerthiYandaOS Aug 21, 2023
3fb63f1
Adding json
KeerthiYandaOS Aug 21, 2023
e4c484d
Adding play
KeerthiYandaOS Aug 21, 2023
c660b0a
Converting to json
KeerthiYandaOS Aug 21, 2023
6d3641d
Adding scalac plugin
KeerthiYandaOS Aug 22, 2023
48af251
Version update
KeerthiYandaOS Aug 22, 2023
431cf83
Updating scoverage version
KeerthiYandaOS Aug 22, 2023
891a5e5
Adding scalariform
KeerthiYandaOS Aug 22, 2023
2c4f42a
Adding versionScheme
KeerthiYandaOS Aug 22, 2023
bbb73a6
update version
KeerthiYandaOS Aug 22, 2023
c3b3be2
Updating the path
KeerthiYandaOS Aug 22, 2023
8291eb7
Update hadoop
KeerthiYandaOS Aug 22, 2023
a61d6cb
Updating breeze version
KeerthiYandaOS Aug 22, 2023
503703f
Remove scalactic exclusion
KeerthiYandaOS Aug 22, 2023
8c0365c
Adding scikit-learn
KeerthiYandaOS Aug 22, 2023
69ef571
removing versions
KeerthiYandaOS Aug 23, 2023
094ec40
Adding scikit-learn
KeerthiYandaOS Aug 23, 2023
7c477bd
try fix LightGBM unit test
svotaw Aug 29, 2023
2cc7b6d
Adding sklearn
KeerthiYandaOS Aug 23, 2023
f3898c5
Adding exclusion
KeerthiYandaOS Aug 27, 2023
0010c48
Exclude avro
KeerthiYandaOS Aug 28, 2023
137a087
exc sql
KeerthiYandaOS Aug 28, 2023
7a8cb14
Adding protobuf
KeerthiYandaOS Aug 29, 2023
78d2da6
Updating horovod to 0.28.1
KeerthiYandaOS Aug 29, 2023
4bb54d2
Update Quickstart - Document Question and Answering with PDFs.ipynb
BrendanWalsh Aug 31, 2023
4e2bcc3
Fixing Databricks Runtime and Tests
KeerthiYandaOS Aug 31, 2023
dc1b333
Adding Reccomendations fix
KeerthiYandaOS Aug 31, 2023
894bf96
Adding recomnd fix
KeerthiYandaOS Aug 31, 2023
efb770a
formating
KeerthiYandaOS Aug 31, 2023
f59a40c
Skipping SynapseE2E and Io2
KeerthiYandaOS Aug 31, 2023
3104981
Addressing comment
KeerthiYandaOS Sep 1, 2023
af104fb
Removed the commented code
KeerthiYandaOS Sep 1, 2023
d24bb83
Adding comment
KeerthiYandaOS Sep 1, 2023
b3545ec
Adding comments
KeerthiYandaOS Sep 1, 2023
c5ae641
Added default values
KeerthiYandaOS Sep 1, 2023
4ef41cd
Delete .github/workflows/ci-tests-r.yml
JessicaXYWang Sep 1, 2023
20 changes: 11 additions & 9 deletions build.sbt
@@ -8,10 +8,10 @@ import scala.xml.transform.{RewriteRule, RuleTransformer}
import scala.xml.{Node => XmlNode, NodeSeq => XmlNodeSeq, _}

val condaEnvName = "synapseml"
val sparkVersion = "3.2.3"
val sparkVersion = "3.4.1"
name := "synapseml"
ThisBuild / organization := "com.microsoft.azure"
ThisBuild / scalaVersion := "2.12.15"
ThisBuild / scalaVersion := "2.12.17"

val scalaMajorVersion = 2.12

@@ -21,22 +21,24 @@ val excludes = Seq(
)

val coreDependencies = Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "compile",
// Excluding protobuf-java, as spark-core is bringing the older version transitively.
"org.apache.spark" %% "spark-core" % sparkVersion % "compile" exclude("com.google.protobuf", "protobuf-java"),
"org.apache.spark" %% "spark-mllib" % sparkVersion % "compile",
"org.apache.spark" %% "spark-avro" % sparkVersion % "provided",
"org.apache.spark" %% "spark-avro" % sparkVersion % "compile",
"org.apache.spark" %% "spark-tags" % sparkVersion % "test",
"org.scalatest" %% "scalatest" % "3.2.14" % "test")
val extraDependencies = Seq(
"commons-lang" % "commons-lang" % "2.6",
"org.scalactic" %% "scalactic" % "3.2.14",
"io.spray" %% "spray-json" % "1.3.5",
"com.jcraft" % "jsch" % "0.1.54",
"org.apache.httpcomponents.client5" % "httpclient5" % "5.1.3",
"org.apache.httpcomponents" % "httpmime" % "4.5.13",
"com.linkedin.isolation-forest" %% "isolation-forest_3.2.0" % "2.0.8",
// Although breeze 1.2 is already provided by Spark, this is needed for Azure Synapse Spark 3.2 pools.
// Otherwise a NoSuchMethodError will be thrown by interpretability code. This problem only happens
// to Azure Synapse Spark 3.2 pools.
"org.scalanlp" %% "breeze" % "1.2"
// As isolation-forest_3.2.0 is built for Spark 3.2, exclude the jars that are incompatible with Spark 3.4.
"com.linkedin.isolation-forest" %% "isolation-forest_3.2.0" % "2.0.8" exclude("com.google.protobuf", "protobuf-java") exclude("org.apache.spark", "spark-mllib_2.12") exclude("org.apache.spark", "spark-core_2.12") exclude("org.apache.spark", "spark-avro_2.12") exclude("org.apache.spark", "spark-sql_2.12"),
// Although breeze 2.1.0 is already provided by Spark, this is needed for Azure Synapse Spark 3.4 pools.
// Otherwise a NoSuchMethodError will be thrown by interpretability code.
"org.scalanlp" %% "breeze" % "2.1.0"
).map(d => d excludeAll (excludes: _*))
val dependencies = coreDependencies ++ extraDependencies
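For readers less familiar with the sbt idiom used in this hunk, here is a minimal, hypothetical sketch of the two exclusion mechanisms it leans on: a per-dependency exclude(org, name) for the protobuf clash, and a shared rule list applied to every entry via excludeAll. The coordinates and the rule are illustrative, not the real SynapseML dependency list.

// Hypothetical build.sbt fragment illustrating the exclusion pattern (not the actual file).
val exampleSparkVersion = "3.4.1"

// 1) Drop one transitive artifact from a single dependency.
libraryDependencies +=
  "org.apache.spark" %% "spark-core" % exampleSparkVersion % "compile" exclude("com.google.protobuf", "protobuf-java")

// 2) Apply a shared exclusion-rule list to a whole group of dependencies.
val sharedExcludes = Seq(
  ExclusionRule("org.scala-lang.modules", "scala-collection-compat_2.12") // illustrative rule only
)
libraryDependencies ++= Seq(
  "org.scalanlp" %% "breeze" % "2.1.0"
).map(_.excludeAll(sharedExcludes: _*))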

@@ -68,11 +68,11 @@ object PyCodegen {
// There's `Already borrowed` error found in transformers 4.16.2 when using tokenizers
s"""extras_require={"extras": [
| "cmake",
| "horovod==0.25.0",
| "horovod==0.28.1",
| "pytorch_lightning>=1.5.0,<1.5.10",
| "torch==1.11.0",
| "torchvision>=0.12.0",
| "transformers==4.15.0",
| "torch==1.13.1",
| "torchvision>=0.14.1",
| "transformers==4.32.1",
| "petastorm>=0.12.0",
| "huggingface-hub>=0.8.1",
|]},
@@ -18,7 +18,7 @@ object PackageUtils {

val PackageName = s"synapseml_$ScalaVersionSuffix"
val PackageMavenCoordinate = s"$PackageGroup:$PackageName:${BuildInfo.version}"
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.3.1"
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.4.1"
val PackageRepository: String = SparkMLRepository

// If testing onnx package with snapshots repo, make sure to switch to using
@@ -3,7 +3,7 @@

package com.microsoft.azure.synapse.ml.exploratory

import breeze.stats.distributions.ChiSquared
import breeze.stats.distributions.{ChiSquared, RandBasis}
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
@@ -261,6 +261,7 @@ private[exploratory] case class DistributionMetrics(numFeatures: Int,

// Calculates left-tailed p-value from degrees of freedom and chi-squared test statistic
def chiSquaredPValue: Column = {
implicit val rand: RandBasis = RandBasis.mt0
val degOfFreedom = numFeatures - 1
val scoreCol = chiSquaredTestStatistic
val chiSqPValueUdf = udf(
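The RandBasis lines in this hunk track a breeze API change: in breeze 2.x the distribution constructors expect an implicit RandBasis rather than a hidden global default. A small, self-contained sketch of the pattern, assuming the p-value is taken from the ChiSquared CDF as the comment above describes:

// Hedged sketch of the breeze 2.1.0 usage: ChiSquared needs an implicit RandBasis.
// RandBasis.mt0 is a Mersenne Twister seeded with 0, keeping results deterministic.
import breeze.stats.distributions.{ChiSquared, RandBasis}

object ChiSquaredPValueSketch {
  def leftTailedPValue(testStatistic: Double, degreesOfFreedom: Int): Double = {
    implicit val rand: RandBasis = RandBasis.mt0
    ChiSquared(degreesOfFreedom.toDouble).cdf(testStatistic) // left-tailed p-value
  }

  def main(args: Array[String]): Unit =
    // With 1 degree of freedom, a statistic of 3.84 sits near the 95th percentile.
    println(leftTailedPValue(3.84, 1))
}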
@@ -194,7 +194,7 @@ private[ml] class HadoopFileReader(file: PartitionedFile,

private val iterator = {
val fileSplit = new FileSplit(
new Path(new URI(file.filePath)),
new Path(new URI(file.filePath.toString())),
file.start,
file.length,
Array.empty)
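The added .toString() reflects a Spark 3.4 type change: PartitionedFile.filePath is no longer a plain String but a SparkPath wrapper, so building a URI or Hadoop Path needs an explicit conversion. A hypothetical helper showing the shape of the call:

// Illustrative only: convert whatever Spark hands back into a Hadoop Path.
// Spark 3.2 exposed filePath as a String; Spark 3.4 exposes a SparkPath,
// so callers now pass file.filePath.toString in both code paths.
import java.net.URI
import org.apache.hadoop.fs.Path

object PathConversionSketch {
  def toHadoopPath(filePath: String): Path = new Path(new URI(filePath))
}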
@@ -3,7 +3,6 @@

package com.microsoft.azure.synapse.ml.nn

import breeze.linalg.functions.euclideanDistance
import breeze.linalg.{DenseVector, norm, _}
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using

@@ -199,17 +199,20 @@ object SparkHelpers {

def flatten(ratings: Dataset[_], num: Int, dstOutputColumn: String, srcOutputColumn: String): DataFrame = {
import ratings.sparkSession.implicits._

val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
.toDF("id", "recommendations")
import org.apache.spark.sql.functions.{collect_top_k, struct}

val arrayType = ArrayType(
new StructType()
.add(dstOutputColumn, IntegerType)
.add("rating", FloatType)
.add(Constants.RatingCol, FloatType)
)
recs.select(col("id").as(srcOutputColumn), col("recommendations").cast(arrayType))

ratings.toDF(srcOutputColumn, dstOutputColumn, Constants.RatingCol).groupBy(srcOutputColumn)
.agg(collect_top_k(struct(Constants.RatingCol, dstOutputColumn), num, false))
.as[(Int, Seq[(Float, Int)])]
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
.toDF(srcOutputColumn, Constants.Recommendations)
.withColumn(Constants.Recommendations, col(Constants.Recommendations).cast(arrayType))
}
}
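The rewritten flatten relies on collect_top_k, which appears to be a Spark-internal helper that is only reachable because SparkHelpers lives under the org.apache.spark namespace. As a hedged illustration of the same "top-k highest-rated items per user" shape using only public APIs (column names are illustrative, and this is not the SynapseML implementation):

// Public-API sketch: keep the k highest-rated items per user as an array of structs.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object TopKPerUserSketch {
  // Assumed input columns: user (int), item (int), rating (float).
  def topKPerUser(ratings: DataFrame, k: Int): DataFrame = {
    val byUser = Window.partitionBy("user").orderBy(col("rating").desc)
    ratings
      .withColumn("rank", row_number().over(byUser))
      .filter(col("rank") <= k)
      .groupBy("user")
      .agg(collect_list(struct(col("item"), col("rating"))).as("recommendations"))
  }
}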

@@ -98,7 +98,7 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
Iterator(emptyUnsafeRow)
} else {
val origin = file.filePath
val path = new Path(origin)
val path = new Path(origin.toString())
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
val stream = fs.open(path)
val bytes = try {
@@ -107,11 +107,12 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
IOUtils.close(stream)
}

val resultOpt = catchFlakiness(5)(ImageSchema.decode(origin, bytes)) //scalastyle:ignore magic.number
val resultOpt = catchFlakiness(5)( //scalastyle:ignore magic.number
ImageSchema.decode(origin.toString(), bytes))
val filteredResult = if (imageSourceOptions.dropInvalid) {
resultOpt.toIterator
} else {
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin.toString())))
}

if (requiredSchema.isEmpty) {
@@ -101,7 +101,7 @@ object RTestGen {
| "spark.sql.shuffle.partitions=10",
| "spark.sql.crossJoin.enabled=true")
|
|sc <- spark_connect(master = "local", version = "3.2.4", config = conf)
|sc <- spark_connect(master = "local", version = "3.4.1", config = conf)
|
|""".stripMargin, StandardOpenOption.CREATE)

@@ -181,6 +181,7 @@ abstract class TestBase extends AnyFunSuite with BeforeAndAfterEachTestData with
}

protected override def beforeAll(): Unit = {
System.setProperty("log4j1.compatibility", "true")
suiteElapsed = 0
}
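The new system property appears to be needed because Spark 3.4 logs through Log4j 2; the log4j1.compatibility flag asks Log4j 2 to keep honouring legacy log4j.properties files that the test resources still ship. A tiny sketch of the idea, under that assumption:

// Hedged sketch: enable the Log4j 1.x compatibility bridge before any logger is created.
object TestLoggingSetup {
  def enableLog4j1Compatibility(): Unit = System.setProperty("log4j1.compatibility", "true")
}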

@@ -3,7 +3,7 @@

package com.microsoft.azure.synapse.ml.exploratory

import breeze.stats.distributions.ChiSquared
import breeze.stats.distributions.{ChiSquared, RandBasis}
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import org.apache.spark.sql.functions.{col, count, lit}
import org.apache.spark.sql.types.DoubleType
@@ -126,6 +126,7 @@ case class DistributionMetricsCalculator(refFeatureProbabilities: Array[Double],
val totalVariationDistance: Double = 0.5d * absDiffObsRef.sum
val wassersteinDistance: Double = absDiffObsRef.sum / absDiffObsRef.length
val chiSquaredTestStatistic: Double = (obsFeatureCounts, refFeatureCounts).zipped.map((a, b) => pow(a - b, 2) / b).sum
implicit val rand: RandBasis = RandBasis.mt0
val chiSquaredPValue: Double = chiSquaredTestStatistic match {
// limit of CDF as x approaches +inf is 1 (https://en.wikipedia.org/wiki/Cumulative_distribution_function)
case Double.PositiveInfinity => 1
@@ -20,14 +20,14 @@ import org.apache.spark.sql.functions.{col, length}
import org.apache.spark.sql.streaming.{DataStreamReader, DataStreamWriter, StreamingQuery}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
import org.json4s.DefaultFormats
import org.json4s.jackson.Serialization

import java.io.File
import java.util.UUID
import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
import scala.concurrent.duration.{Duration, FiniteDuration}
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
import scala.util.parsing.json.JSONObject


trait HTTPTestUtils extends TestBase with WithFreeUrl {

@@ -81,7 +81,8 @@ trait HTTPTestUtils extends TestBase with WithFreeUrl {

def sendJsonRequest(map: Map[String, Any], url: String): String = {
val post = new HttpPost(url)
val params = new StringEntity(JSONObject(map).toString())
implicit val defaultFormats: DefaultFormats = DefaultFormats
val params = new StringEntity(Serialization.write(map))
post.addHeader("content-type", "application/json")
post.setEntity(params)
val res = RESTHelpers.Client.execute(post)
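The json4s change in this hunk replaces scala.util.parsing.json.JSONObject, which is deprecated and no longer available with newer Scala library versions. A standalone sketch of the replacement pattern used by sendJsonRequest:

// Serialize a Map[String, Any] into a JSON request body with json4s.
import org.json4s.DefaultFormats
import org.json4s.jackson.Serialization

object JsonBodySketch {
  def toJsonBody(payload: Map[String, Any]): String = {
    implicit val formats: DefaultFormats = DefaultFormats
    Serialization.write(payload)
  }

  def main(args: Array[String]): Unit =
    println(toJsonBody(Map("name" -> "synapseml", "sparkVersion" -> "3.4.1")))
  // prints {"name":"synapseml","sparkVersion":"3.4.1"}
}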
@@ -10,7 +10,7 @@ import scala.language.existentials

class DatabricksCPUTests extends DatabricksTestHelper {

val clusterId: String = createClusterInPool(ClusterName, AdbRuntime, NumWorkers, PoolId, "[]")
val clusterId: String = createClusterInPool(ClusterName, AdbRuntime, NumWorkers, PoolId)
val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(clusterId, Libraries, CPUNotebooks)

protected override def afterAll(): Unit = {
@@ -11,11 +11,7 @@ import java.io.File
import scala.collection.mutable.ListBuffer

class DatabricksGPUTests extends DatabricksTestHelper {
val horovodInstallationScript: File = FileUtilities.join(
BuildInfo.baseDirectory.getParent, "deep-learning",
"src", "main", "python", "horovod_installation.sh").getCanonicalFile
uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh")
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts)
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId)
val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(
clusterId, GPULibraries, GPUNotebooks)

@@ -29,10 +29,11 @@ object DatabricksUtilities {

// ADB Info
val Region = "eastus"
val PoolName = "synapseml-build-10.4"
val GpuPoolName = "synapseml-build-10.4-gpu"
val AdbRuntime = "10.4.x-scala2.12"
val AdbGpuRuntime = "10.4.x-gpu-ml-scala2.12"
val PoolName = "synapseml-build-13.3"
val GpuPoolName = "synapseml-build-13.3-gpu"
val AdbRuntime = "13.3.x-scala2.12"
// https://docs.databricks.com/en/release-notes/runtime/13.3lts-ml.html
val AdbGpuRuntime = "13.3.x-gpu-ml-scala2.12"
val NumWorkers = 5
val AutoTerminationMinutes = 15

@@ -72,8 +73,11 @@ object DatabricksUtilities {
// TODO: install synapse.ml.dl wheel package here
val GPULibraries: String = List(
Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository)),
Map("pypi" -> Map("package" -> "transformers==4.15.0")),
Map("pypi" -> Map("package" -> "petastorm==0.12.0"))
Map("pypi" -> Map("package" -> "pytorch-lightning==1.5.0")),
Map("pypi" -> Map("package" -> "torchvision==0.14.1")),
Map("pypi" -> Map("package" -> "transformers==4.32.1")),
Map("pypi" -> Map("package" -> "petastorm==0.12.0")),
Map("pypi" -> Map("package" -> "protobuf==3.20.3"))
).toJson.compactPrint

val GPUInitScripts: String = List(
@@ -170,7 +174,7 @@
sparkVersion: String,
numWorkers: Int,
poolId: String,
initScripts: String): String = {
initScripts: String = "[]"): String = {
val body =
s"""
|{
@@ -75,7 +75,6 @@ object SynapseExtensionUtilities {
val store = Secrets.ArtifactStore.capitalize
val excludes: String = "org.scala-lang:scala-reflect," +
"org.apache.spark:spark-tags_2.12," +
"org.scalactic:scalactic_2.12," +
"org.scalatest:scalatest_2.12," +
"org.slf4j:slf4j-api"

@@ -184,7 +184,6 @@ object SynapseUtilities {
val excludes: String = Seq(
"org.scala-lang:scala-reflect",
"org.apache.spark:spark-tags_2.12",
"org.scalactic:scalactic_2.12",
"org.scalatest:scalatest_2.12",
"org.slf4j:slf4j-api").mkString(",")
val runName = abfssPath.split('/').last.replace(".py", "")
@@ -255,7 +254,7 @@
| "nodeSizeFamily": "MemoryOptimized",
| "provisioningState": "Succeeded",
| "sessionLevelPackagesEnabled": "true",
| "sparkVersion": "3.2"
| "sparkVersion": "3.4"
| }
|}
|""".stripMargin
@@ -5,6 +5,8 @@ package com.microsoft.azure.synapse.ml.recommendation

import com.microsoft.azure.synapse.ml.core.test.fuzzing.{EstimatorFuzzing, TestObject, TransformerFuzzing}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.scalactic.Equality

class RankingAdapterSpec extends RankingTestBase with EstimatorFuzzing[RankingAdapter] {
override def testObjects(): Seq[TestObject[RankingAdapter]] = {
@@ -15,6 +17,19 @@

override def modelReader: MLReadable[_] = RankingAdapterModel

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
def prep(df: DataFrame) = {
// sort rows and round decimals before comparing the two dataframes
import org.apache.spark.sql.functions._
val roundListDecimals: Seq[Float] => Seq[Float] = _.map { value =>
BigDecimal(value.toDouble).setScale(6, BigDecimal.RoundingMode.HALF_UP).toFloat
}
val castListToIntUDF = udf(roundListDecimals)
val sortedDF = df.orderBy(col("prediction"))
val updatedDF: DataFrame = sortedDF.withColumn("label", castListToIntUDF(col("label")))
updatedDF
}
super.assertDFEq(prep(df1), prep(df2))(eq)}
}

class RankingAdapterModelSpec extends RankingTestBase with TransformerFuzzing[RankingAdapterModel] {
@@ -24,4 +39,18 @@
}

override def reader: MLReadable[_] = RankingAdapterModel

override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
def prep(df: DataFrame) = {
// sort rows and round decimals before comparing dataframes
import org.apache.spark.sql.functions._
val roundListDecimals: Seq[Float] => Seq[Float] = _.map { value =>
BigDecimal(value.toDouble).setScale(6, BigDecimal.RoundingMode.HALF_UP).toFloat
}
val castListToIntUDF = udf(roundListDecimals)
val sortedDF = df.orderBy(col("prediction"))
val updatedDF: DataFrame = sortedDF.withColumn("label", castListToIntUDF(col("label")))
updatedDF
}
super.assertDFEq(prep(df1), prep(df2))(eq)}
}
13 changes: 5 additions & 8 deletions deep-learning/src/main/python/horovod_installation.sh
@@ -7,8 +7,8 @@ set -eu

# Install prerequisite libraries that horovod depends on
pip install pytorch-lightning==1.5.0
pip install torchvision==0.12.0
pip install transformers==4.15.0
pip install torchvision==0.14.1
pip install transformers==4.32.1
pip install petastorm>=0.12.0
pip install protobuf==3.20.3

@@ -35,16 +35,13 @@ libcusparse-dev-11-0=11.1.1.245-1

git clone --recursive https://github.com/horovod/horovod.git
cd horovod
# # fix version 0.25.0
# git fetch origin refs/tags/v0.25.0:tags/v0.25.0
# git checkout tags/v0.25.0 -b v0.25.0-branch
# fix to this commit number until they release a new version
git checkout ab97fd15bbba3258adcdd12983f36a1cdeacbc94
# git fetch origin refs/tags/v0.28.1:tags/v0.28.1
git checkout 1d217b59949986d025f6db93c49943fb6b6cc78f
git checkout -b tmp-branch
rm -rf build/ dist/
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_CUDA_HOME=/usr/local/cuda-11/ HOROVOD_WITH_PYTORCH=1 HOROVOD_WITHOUT_MXNET=1 \
/databricks/python3/bin/python setup.py bdist_wheel

readlink -f dist/horovod-*.whl

pip install --no-cache-dir dist/horovod-0.25.0-cp38-cp38-linux_x86_64.whl --force-reinstall --no-deps
pip install --no-cache-dir dist/horovod-0.28.1-cp38-cp38-linux_x86_64.whl --force-reinstall --no-deps
@@ -11,12 +11,12 @@
if _TRANSFORMERS_AVAILABLE:
import transformers

_TRANSFORMERS_EQUAL_4_15_0 = transformers.__version__ == "4.15.0"
if _TRANSFORMERS_EQUAL_4_15_0:
_TRANSFORMERS_EQUAL_4_32_1 = transformers.__version__ == "4.32.1"
if _TRANSFORMERS_EQUAL_4_32_1:
from transformers import AutoTokenizer
else:
raise RuntimeError(
"transformers should be == 4.15.0, found: {}".format(
"transformers should be == 4.32.1, found: {}".format(
transformers.__version__
)
)