The Steps to Create, Train, Save, and Load a Spam Detection AI Model Using ML.NET
This article demonstrates how to create, train, save, and load a spam detection model using ML.NET, with an emphasis on reusing the trained model. By following the steps, you will end up with a model that can be saved to disk, reloaded, and integrated into your .NET applications to identify and filter out spam emails.
Prerequisites
- Basic understanding of C#
- Familiarity with ML.NET and machine learning concepts
Code Overview
- Import necessary namespaces:
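using System;
using System.Collections.Generic; // List<T>, used for the sample dataset
using System.Diagnostics;         // Debug.WriteLine
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using NUnit.Framework;            // Assert.* (assumption: the assertions below match NUnit)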
- Define the Email class and its properties:
public class Email
{
    public string Content { get; set; }
    public bool IsSpam { get; set; }
}
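The prediction step later in this walkthrough refers to a SpamPrediction type that is never defined. A minimal sketch of what such a class could look like follows; the class name comes from the article, but its members are an assumption based on the columns that ML.NET binary classification trainers normally emit (PredictedLabel, Probability, Score).

```csharp
using Microsoft.ML.Data;

// Hypothetical output type: only the name SpamPrediction appears in the article.
// The members below are assumed from ML.NET's default binary classification output.
public class SpamPrediction
{
    // Maps the trainer's "PredictedLabel" output column onto IsSpam,
    // so that prediction.IsSpam can be read as in the later steps.
    [ColumnName("PredictedLabel")]
    public bool IsSpam { get; set; }

    // Calibrated probability that the email is spam, between 0 and 1.
    public float Probability { get; set; }

    // Raw, uncalibrated score produced by the linear model.
    public float Score { get; set; }
}
```

Without the ColumnName attribute, ML.NET would look for an output column literally named IsSpam, which the trainer does not produce.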
- Create a sample dataset for training the model:
var sampleData = new List<Email>
{
    new Email { Content = "Buy cheap products now", IsSpam = true },
    new Email { Content = "Meeting at 3 PM", IsSpam = false },
};
- Initialize a new MLContext, which is the main entry point to ML.NET:
var mlContext = new MLContext();
- Load the sample data into an IDataView:
var trainData = mlContext.Data.LoadFromEnumerable(sampleData);
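- Define the data processing pipeline and the training algorithm (SdcaLogisticRegression). The trainer looks for a label column named "Label" by default, so the label column is set to IsSpam explicitly:
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", nameof(Email.Content))
    .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(
        labelColumnName: nameof(Email.IsSpam),
        featureColumnName: "Features"));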
- Train the model:
var model = pipeline.Fit(trainData);
- Save the trained model to a file (ML.NET serializes the model, together with its input schema, as a .zip archive):
mlContext.Model.Save(model, trainData.Schema, "model.zip");
- Load the saved model:
var newMlContext = new MLContext();
DataViewSchema modelSchema;
ITransformer trainedModel = newMlContext.Model.Load("model.zip", out modelSchema);
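- Create a prediction engine from the loaded model, using the same MLContext that loaded it. SpamPrediction is the output class; its IsSpam property must map to the model's PredictedLabel column:
var predictionEngine = newMlContext.Model.CreatePredictionEngine<Email, SpamPrediction>(trainedModel);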
- Test the model with a sample email:
var sampleEmail = new Email { Content = "Special discount, buy now!" };
var prediction = predictionEngine.Predict(sampleEmail);
- Output the prediction:
Debug.WriteLine($"Email: '{sampleEmail.Content}' is {(prediction.IsSpam ? "spam" : "not spam")}");
- Assert that the prediction is correct:
Assert.IsTrue(prediction.IsSpam);
- Verify that the model was saved:
if (File.Exists("model.zip"))
    Assert.Pass();
else
    Assert.Fail();
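The steps above can be pulled together into a single console program. What follows is a minimal sketch rather than a definitive implementation: the method names TrainAndSave and LoadAndPredict, the members of SpamPrediction, and the two extra training sentences are assumptions added for illustration, and it prints the result instead of using test assertions. With so little training data, the prediction itself should not be relied on.

```csharp
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;

public class Email
{
    public string Content { get; set; }
    public bool IsSpam { get; set; }
}

// Assumed output shape; maps the trainer's PredictedLabel column onto IsSpam.
public class SpamPrediction
{
    [ColumnName("PredictedLabel")]
    public bool IsSpam { get; set; }
    public float Probability { get; set; }
}

public static class Program
{
    public static void Main()
    {
        const string modelPath = "model.zip";
        TrainAndSave(modelPath);

        // A separate MLContext inside LoadAndPredict shows the model is reusable
        // without access to the training code or data.
        var content = "Special discount, buy now!";
        SpamPrediction prediction = LoadAndPredict(modelPath, content);
        Console.WriteLine(
            $"Email: '{content}' is {(prediction.IsSpam ? "spam" : "not spam")} " +
            $"(probability {prediction.Probability:F2})");
    }

    // Trains on a tiny in-memory dataset and writes the model to disk.
    public static void TrainAndSave(string path)
    {
        var mlContext = new MLContext(seed: 0);
        var sampleData = new List<Email>
        {
            new Email { Content = "Buy cheap products now", IsSpam = true },
            new Email { Content = "Meeting at 3 PM", IsSpam = false },
            // Illustrative extra rows, not from the article.
            new Email { Content = "Limited offer, buy today", IsSpam = true },
            new Email { Content = "Lunch tomorrow at noon?", IsSpam = false },
        };
        IDataView trainData = mlContext.Data.LoadFromEnumerable(sampleData);

        var pipeline = mlContext.Transforms.Text
            .FeaturizeText("Features", nameof(Email.Content))
            .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(
                labelColumnName: nameof(Email.IsSpam),
                featureColumnName: "Features"));

        ITransformer model = pipeline.Fit(trainData);
        mlContext.Model.Save(model, trainData.Schema, path);
    }

    // Loads the saved model and scores a single email.
    public static SpamPrediction LoadAndPredict(string path, string content)
    {
        var mlContext = new MLContext();
        ITransformer model = mlContext.Model.Load(path, out DataViewSchema inputSchema);
        using var engine =
            mlContext.Model.CreatePredictionEngine<Email, SpamPrediction>(model);
        return engine.Predict(new Email { Content = content });
    }
}
```

Keeping training and prediction in separate methods mirrors the save/load split in the steps: an application that only filters email needs LoadAndPredict and the model file. Note that PredictionEngine is not thread-safe, so a multi-threaded application should not share one instance across threads.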
Conclusion
In this article, we built a simple spam detection model in ML.NET and showed how to train it, save it to disk, load it back, and test it on a sample email. The code can be extended to build more complex models, for example by training on a larger labeled dataset, and can serve as a starting point for exploring machine learning in .NET.