Прогнозирование с помощью Apache Spark MLJAVA

Программисты JAVA общаются здесь
Ответить
Anonymous
 Прогнозирование с помощью Apache Spark ML

Сообщение Anonymous »

Я новичок в Apache Spark ML.
Я хотел бы получить прогноз баланса по возрасту и стране. В качестве входных данных у меня есть файл CSV в следующем формате: RowNumber,Age,Country,Balance.
Модель построена и также может быть обучена на тестовых данных. Пока все работает.
Моя проблема сейчас в том, что я хочу сделать прогноз для новой записи клиента.
Dataset newCustomer = spark.createDataFrame(Collections.singletonList(
new Customer(28, ‘Germany’)), Customer.class);
Dataset newCustomerPrediction = model.transform(newCustomer);

Я получаю следующее сообщение об ошибке:
java.lang.IllegalArgumentException: CountryIndex не существует. Доступно: возраст, страна.
Как я могу получить прогноз для нового набора данных?
Привет,
Марио
public static void main(String[] args) {

SparkSession spark = SparkSession
.builder()
.master("local[*]")
.appName("JavaGeneralizedLinearRegressionExample")
.getOrCreate();

Dataset data = spark.read()
.option("header", "true")
.option("inferSchema", "true")
.option("delimiter", ",") // oder "," je nach Dateiformat
.csv("/data/testdaten_v4.csv");

StringIndexer countryIndexer = new StringIndexer()
.setInputCol("Country")
.setOutputCol("CountryIndex")
.setHandleInvalid("skip");
OneHotEncoder countryEncoder = new OneHotEncoder()
.setInputCol("CountryIndex")
.setOutputCol("CountryVec");

VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"Age", "CountryVec"}) // andere Features hinzufügen falls nötig
.setOutputCol("features");

StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures");

LinearRegression lr = new LinearRegression()
.setLabelCol("Balance")
.setFeaturesCol("scaledFeatures")
.setMaxIter(100)
.setRegParam(0.3)
.setElasticNetParam(0.8);

Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{countryIndexer, countryEncoder, assembler, scaler, lr});

PipelineModel model = pipeline.fit(data);

Dataset[] splits = data.randomSplit(new double[]{0.8, 0.2}, 42);
Dataset trainData = splits[0];
Dataset testData = splits[1];

Dataset predictions = model.transform(testData);
predictions.select("Age", "Country", "Balance", "prediction").show();

RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("Balance")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);

Dataset newCustomer = spark.createDataFrame(Collections.singletonList(
new Customer(28, "Germany")), Customer.class);
Dataset newCustomerPrediction = model.transform(newCustomer);
newCustomerPrediction.select("prediction").show();

spark.stop();
}

public static class Customer {
private int Age;
private String Country;

public Customer(int age, String country) {
this.Age = age;
this.Country = country;
}

public int getAge() { return Age; }
public String getCountry() { return Country; }
}


Подробнее здесь: https://stackoverflow.com/questions/791 ... e-spark-ml
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «JAVA»