Spark隨機(jī)森林實(shí)現(xiàn)票房預(yù)測(cè)
前言
最近一段時(shí)間都在處理電影領(lǐng)域的數(shù)據(jù), 而電影票房預(yù)測(cè)是電影領(lǐng)域數(shù)據(jù)建模中的一個(gè)重要模塊, 所以我們針對(duì)電影數(shù)據(jù)做了票房預(yù)測(cè)建模.
前期工作
一開(kāi)始的做法是將這個(gè)問(wèn)題看待成回歸的問(wèn)題, 采用GBDT回歸樹(shù)去做. 訓(xùn)練了不同殘差的回歸樹(shù), 然后做集成學(xué)習(xí). 考慮的影響因子分別有電影的類(lèi)型, 豆瓣評(píng)分, 導(dǎo)演的 影響力, 演員的影響力, 電影的出品公司. 不過(guò)預(yù)測(cè)的結(jié)果并不是那么理想, 準(zhǔn)確率為真實(shí)值的0.3+/-區(qū)間情況下的80%, 且波動(dòng)性較大, 不容易解析.
后期的改進(jìn)
總結(jié)之前的失敗經(jīng)驗(yàn), 主要?dú)w納了以下幾點(diǎn):
1.影響因子不夠多, 難以建模
2.票房成績(jī)的區(qū)間較大(一百萬(wàn)到10億不等),分布不均勻, 大多數(shù)集中與億級(jí), 所以不適合采用回歸方法解決.
3.數(shù)據(jù)樣本量比較少, 不均勻, 預(yù)測(cè)百萬(wàn)級(jí)的電影較多, 影響預(yù)測(cè)結(jié)果
后期, 我們重新規(guī)范了數(shù)據(jù)的輸入格式, 即影響因子, 具體如下:
第一行: 電影名字
第二行: 電影票房(也就是用于預(yù)測(cè)的, 以萬(wàn)為單位)
第三行: 電影類(lèi)型
第四行: 片長(zhǎng)(以分鐘為單位)
第五行:上映時(shí)間(按月份)
第六行: 制式( 一般分為2D, 3D, IMAX)
第七行: 制作國(guó)家
第八行: 導(dǎo)演影響 (以導(dǎo)演的平均票房成績(jī)?yōu)楹饬? 以萬(wàn)為單位 )
第九行: 演員影響 ( 以所有演員的平均票房成績(jī)?yōu)楹饬? 以萬(wàn)為單位 )
第十行:制作公司影響 ( 以所有制作公司的平均票房成績(jī)?yōu)楹饬? 以萬(wàn)為單位 )
第十一行: 發(fā)行公式影響 ( 以所有制作公司的平均票房成績(jī)?yōu)楹饬?以萬(wàn)為單位 )
收集了05-17年的來(lái)自中國(guó),日本,美國(guó),英國(guó)的電影, 共1058部電影. 由于處理成為分類(lèi)問(wèn)題, 故按將電影票房分為以下等級(jí):

在構(gòu)建模型之前, 先將數(shù)據(jù)處理成libsvm格式文件, 然后采用隨機(jī)森林模型訓(xùn)練.
隨機(jī)森林由許多的決策樹(shù)組成, 因?yàn)檫@些決策樹(shù)的形成采用隨機(jī)的策略, 每個(gè)決策樹(shù)都隨機(jī)生成, 相互之間獨(dú)立.模型最后輸出的類(lèi)別是由每個(gè)樹(shù)輸出的類(lèi)別的眾數(shù)而定.在構(gòu)建每個(gè)決策樹(shù)的時(shí)候采用的策略是信息熵, 決策樹(shù)為多元分類(lèi)決策樹(shù).隨機(jī)森林的流程圖如下圖所示:

隨機(jī)森林是采用spark-mllib提供的random forest, 由于超過(guò)10億的電影的數(shù)據(jù)相對(duì)比較少, 為了平衡各數(shù)據(jù)的分布, 采用了過(guò)分抽樣的方法, 訓(xùn)練模型的代碼如下:
public void predict() throws IOException{
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "2147480000");
SparkContext sc = new SparkContext(conf);
SQLContext sqlContext = new SQLContext(sc);
// Load and parse the data file, converting it to a DataFrame.
DataFrame trainData = sqlContext.read().format("libsvm").load(this.trainFile);
DataFrame testData = sqlContext.read().format("libsvm").load(this.testFile);
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
StringIndexerModel labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(trainData);
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
VectorIndexerModel featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(trainData);
// Split the data into training and test sets (30% held out for testing)
// DataFrame[] splits = trainData.randomSplit(new double[] {0.9, 0.1});
// trainData = splits[0];
// testData = splits[1];
// Train a RandomForest model.
RandomForestClassifier rf = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setNumTrees(20);
// Convert indexed labels back to original labels.
IndexToString labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels());
// Chain indexers and forest in a Pipeline
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});
// Train model. This also runs the indexers.
PipelineModel model = pipeline.fit(trainData);
// Make predictions.
DataFrame predictions = model.transform(testData);
// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(200);
// Select (prediction, true label) and compute test error
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("precision");
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error = " + (1.0 - accuracy));
RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]);
// System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
DataFrame resultDF = predictions.select("predictedLabel");
JavaRDD<Row> resultRow = resultDF.toJavaRDD();
JavaRDD<String> result = resultRow.map(new Result());
this.resultList = result.collect();
for(String one: resultList){
System.out.println(one);
}
}
下面為其中一個(gè)的決策樹(shù)情況:
Tree 16 (weight 1.0):
If (feature 10 in {0.0})
If (feature 48 <= 110.0)
If (feature 86 <= 13698.87)
If (feature 21 in {0.0})
If (feature 54 in {0.0})
Predict: 0.0
Else (feature 54 not in {0.0})
Predict: 1.0
Else (feature 21 not in {0.0})
Predict: 0.0
Else (feature 86 > 13698.87)
If (feature 21 in {0.0})
If (feature 85 <= 39646.9)
Predict: 2.0
Else (feature 85 > 39646.9)
Predict: 3.0
Else (feature 21 not in {0.0})
Predict: 3.0
Else (feature 48 > 110.0)
If (feature 85 <= 15003.3)
If (feature 9 in {0.0})
If (feature 54 in {0.0})
Predict: 0.0
Else (feature 54 not in {0.0})
Predict: 2.0
Else (feature 9 not in {0.0})
Predict: 2.0
Else (feature 85 > 15003.3)
If (feature 65 in {0.0})
If (feature 85 <= 66065.0)
Predict: 3.0
Else (feature 85 > 66065.0)
Predict: 2.0
Else (feature 65 not in {0.0})
Predict: 3.0
Else (feature 10 not in {0.0})
If (feature 51 in {0.0})
If (feature 85 <= 6958.4)
If (feature 11 in {0.0})
If (feature 50 <= 1.0)
Predict: 1.0
Else (feature 50 > 1.0)
Predict: 0.0
Else (feature 11 not in {0.0})
Predict: 0.0
Else (feature 85 > 6958.4)
If (feature 5 in {0.0})
If (feature 4 in {0.0})
Predict: 3.0
Else (feature 4 not in {0.0})
Predict: 1.0
Else (feature 5 not in {0.0})
Predict: 2.0
Else (feature 51 not in {0.0})
If (feature 48 <= 148.0)
If (feature 0 in {0.0})
If (feature 6 in {0.0})
Predict: 2.0
Else (feature 6 not in {0.0})
Predict: 0.0
Else (feature 0 not in {0.0})
If (feature 50 <= 4.0)
Predict: 2.0
Else (feature 50 > 4.0)
Predict: 3.0
Else (feature 48 > 148.0)
If (feature 9 in {0.0})
If (feature 49 <= 3.0)
Predict: 2.0
Else (feature 49 > 3.0)
Predict: 0.0
Else (feature 9 not in {0.0})
If (feature 36 in {0.0})
Predict: 3.0
Else (feature 36 not in {0.0})
Predict: 2.0
后記
該模型預(yù)測(cè)的平均準(zhǔn)確率為80%, 但相對(duì)之前的做法規(guī)范了很多, 對(duì)結(jié)果的解析也更加的合理, 不過(guò)如何增強(qiáng)預(yù)測(cè)的效果, 可以考慮更多的因子, 形如:電影是否有前續(xù);電影網(wǎng)站的口碑指數(shù);預(yù)告片的播放量;相關(guān)微博的閱讀數(shù);百度指數(shù)等;
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
Java 中如何使用 JavaFx 庫(kù)標(biāo)注文本顏色
這篇文章主要介紹了在 Java 中用 JavaFx 庫(kù)標(biāo)注文本顏色,在本文中,我們將了解如何更改標(biāo)簽的文本顏色,并且我們還將看到一個(gè)必要的示例和適當(dāng)?shù)慕忉專(zhuān)员愀菀桌斫庠撝黝},需要的朋友可以參考下2023-05-05
Java中Comparator與Comparable排序的區(qū)別詳解
這篇文章主要介紹了Java中Comparator與Comparable排序的區(qū)別詳解,如果你有一個(gè)類(lèi),希望支持同類(lèi)型的自定義比較策略,可以實(shí)現(xiàn)接口Comparable,如果某個(gè)類(lèi),沒(méi)有實(shí)現(xiàn)Comparable,但是又希望對(duì)它進(jìn)行比較,則可以自定義一個(gè)Comparator,需要的朋友可以參考下2024-01-01
MyBatis-plus+達(dá)夢(mèng)數(shù)據(jù)庫(kù)實(shí)現(xiàn)自動(dòng)生成代碼的示例
這篇文章主要介紹了MyBatis-plus+達(dá)夢(mèng)數(shù)據(jù)庫(kù)實(shí)現(xiàn)自動(dòng)生成代碼的示例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08
Java簡(jiǎn)易登錄注冊(cè)功能實(shí)現(xiàn)代碼解析
這篇文章主要介紹了Java簡(jiǎn)易登錄注冊(cè)功能實(shí)現(xiàn)代碼解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-06-06
SpringBoot項(xiàng)目讀取外置logback配置文件的問(wèn)題及解決
SpringBoot項(xiàng)目讀取外置logback配置文件的問(wèn)題及解決,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2022-08-08
Java8方法引用及構(gòu)造方法引用原理實(shí)例解析
這篇文章主要介紹了Java8方法引用及構(gòu)造方法引用原理實(shí)例解析,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-09-09
MyBatis中的SQL映射文件如何配置參數(shù)映射和使用方法
MyBatis 是一種開(kāi)源的 Java 持久化框架,它可以自動(dòng)將數(shù)據(jù)庫(kù)中的數(shù)據(jù)映射到 Java 對(duì)象中,并且使得 Java 對(duì)象可以非常方便地存儲(chǔ)到數(shù)據(jù)庫(kù)中,本文將介紹 MyBatis 中 SQL 映射文件的參數(shù)映射配置和使用方法,需要的朋友可以參考下2023-07-07
JavaWeb 文件的上傳和下載功能簡(jiǎn)單實(shí)現(xiàn)代碼
這篇文章主要介紹了JavaWeb 文件的上傳和下載功能簡(jiǎn)單實(shí)現(xiàn)代碼,需要的朋友可以參考下2017-04-04

