Brevemente sobre la biblioteca Rumale para aprendizaje automático en Ruby / Sudo Null IT News

¡Hola Habr!

la biblioteca Eres hombre creado para hacer que el aprendizaje automático sea accesible y conveniente para los desarrolladores de Ruby. Tiene una gran selección de algoritmos y herramientas similares a las que se encuentran en Scikit-learn para Python.

Se eligió el formato breve del artículo debido a las similitudes con Sckit learn.

podemos descubrir

Abra Gemfile y agregue la línea:

gem 'rumale'

Después de eso usamos bundle install para instalar la biblioteca:

$ bundle install

Si desea instalar Rumale sin Bundler, puede hacerlo directamente mediante el comando gem install:

$ gem install rumale

Después de instalar la biblioteca, conéctela al proyecto:

require 'rumale'

Construcción y entrenamiento de modelos en Rumale.

Cargaremos datos usando las bibliotecas Daru y RDatasets.

Regresión lineal

La regresión lineal es una base para predecir valores numéricos. Rumale usa la clase para este propósito. Rumale::LinearModel::LinearRegression:

require 'daru'
require 'rumale'

# создание набора данных
data = Daru::DataFrame.from_csv('housing_prices.csv')
x = data('size').to_a
y = data('price').to_a

# преобразование данных в формат, подходящий для Rumale
x = Numo::DFloat(x).reshape(x.size, 1)
y = Numo::DFloat(y)

# построение и обучение модели линейной регрессии
model = Rumale::LinearModel::LinearRegression.new
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Los datos sobre el tamaño y el precio de la vivienda se descargan de un archivo CSV, se convierten en matrices y luego se utilizan para entrenar un modelo de regresión lineal.

Máquina de vectores de soporte (SVM)

Support Vector Machine es un algoritmo para problemas de clasificación. En Rumale está representado por la clase. Rumale::LinearModel::SVC:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris(0..3).to_matrix
y = iris('Species').map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat(*x.to_a)
y = Numo::Int32(*y)

# построение и обучение модели SVM
model = Rumale::LinearModel::SVC.new(kernel: 'linear', reg_param: 1.0)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

El modelo SVM clasifica las flores como setosa O no.

Agrupación utilizando K-Means

K-Means es un algoritmo de agrupación que agrupa datos en función de similitudes. Rumale usa la clase Rumale::Clustering::KMeans:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris(0..3).to_matrix

# преобразование данных в формат Numo::NArray
x = Numo::DFloat(*x.to_a)

# построение и обучение модели K-Means
model = Rumale::Clustering::KMeans.new(n_clusters: 3, max_iter: 300)
model.fit(x)

# предсказание кластеров
labels = model.predict(x)
puts "Кластеры: #{labels.to_a}"

Usamos los datos de Iris para agruparlos en tres grupos usando K-Means.

Otros algoritmos

Bosque aleatorio:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris(0..3).to_matrix
y = iris('Species').map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat(*x.to_a)
y = Numo::Int32(*y)

# построение и обучение модели Random Forest
model = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 10, max_depth: 3)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Aumento de gradiente:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris(0..3).to_matrix
y = iris('Species').map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat(*x.to_a)
y = Numo::Int32(*y)

# построение и обучение модели Gradient Boosting
model = Rumale::Ensemble::GradientBoostingClassifier.new(n_estimators: 100, learning_rate: 0.1, max_depth: 3)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Evaluación y validación de modelos.

Métricas para evaluar la calidad de los modelos.

Error cuadrático medio (MSE): Mide el promedio de errores al cuadrado, es decir, la diferencia entre los valores previstos y reales:

require 'numo/narray'
require 'rumale'

# пример данных
y_true = Numo::DFloat(3.0, -0.5, 2.0, 7.0)
y_pred = Numo::DFloat(2.5, 0.0, 2.0, 8.0)

# расчет MSE
mse = Rumale::EvaluationMeasure::MeanSquaredError.new
mse_value = mse.score(y_true, y_pred)
puts "MSE: #{mse_value}"

Coeficiente de determinación (R²): Mide la proporción de varianza explicada por el modelo. El valor R² varía de 0 a 1, siendo 1 un ajuste perfecto:

# расчет R²
r2 = Rumale::EvaluationMeasure::RSquared.new
r2_value = r2.score(y_true, y_pred)
puts "R²: #{r2_value}"

Validación cruzada

La validación cruzada nos permite evaluar la capacidad de generalización de un modelo. Uno de los métodos más comunes es la validación cruzada K-Fold.

Validación cruzada de K-Fold:

require 'rumale'
require 'daru'
require 'rdatasets'

# загрузка данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris(0..3).to_matrix
y = iris('Species').map { |species| species == 'setosa' ? 0 : 1 }

x = Numo::DFloat(*x.to_a)
y = Numo::Int32(*y)

# определение модели
model = Rumale::LinearModel::LogisticRegression.new

# определение метрики оценки
mse = Rumale::EvaluationMeasure::MeanSquaredError.new

# настройка K-Fold кросс-валидации
kf = Rumale::ModelSelection::KFold.new(n_splits: 5, shuffle: true, random_seed: 1)

# проведение кросс-валидации
cv = Rumale::ModelSelection::CrossValidation.new(estimator: model, splitter: kf, evaluator: mse)
report = cv.perform(x, y)

# вывод результатов
mean_score = report(:test_score).sum / kf.n_splits
puts "5-CV MSE: #{mean_score}"

Después de realizar una validación cruzada u otros métodos de evaluación, es muy importante no olvidar que también es necesario interpretar correctamente los resultados obtenidos.

Media y desviación estándar: Estos indicadores dan una idea de la estabilidad y fiabilidad del modelo. Por ejemplo, promedio bajo. un valor de error y una desviación estándar baja indican un modelo estable y preciso:

mean_score = report(:test_score).mean
std_score = report(:test_score).std
puts "Mean MSE: #{mean_score}, Standard Deviation: #{std_score}"

Aún puedes conectarte trama gnuplot, para visualizar y ayudar a comprender el rendimiento del modelo en varios conjuntos de datos:

require 'gnuplot'

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "K-Fold Cross Validation Scores"
    plot.ylabel "MSE"
    plot.xlabel "Fold"

    plot.data << Gnuplot::DataSet.new(report(:test_score)) do |ds|
      ds.with = "linespoints"
      ds.title = "Fold MSE"
    end
  end
end

Aprende más con esta maravillosa biblioteca se puede encontrar aquí.

Y siempre podrá familiarizarse con otras herramientas y bibliotecas dentro cursos prácticos en línea de mis colegas de OTUS.

Publicaciones Similares

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *