■ ML.NET의 다중 클래스 분류 모델을 사용해 이슈의 제목(Title)과 설명(Description)으로부터 영역(Area)을 예측하는 방법을 보여준다.
▶ IssueData.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// A single issue record. Each property is bound to a source column by position
    /// through <see cref="LoadColumnAttribute"/>.
    /// </summary>
    public class IssueData
    {
        /// <summary>Issue identifier (source column 0).</summary>
        [LoadColumn(0)]
        public string ID { get; set; }

        /// <summary>Area the issue belongs to (source column 1); this is the value the classifier learns to predict.</summary>
        [LoadColumn(1)]
        public string Area { get; set; }

        /// <summary>Issue title (source column 2).</summary>
        [LoadColumn(2)]
        public string Title { get; set; }

        /// <summary>Issue description text (source column 3).</summary>
        [LoadColumn(3)]
        public string Description { get; set; }
    }
}
▶ IssuePrediction.cs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// Output row produced by the issue classifier's prediction engine.
    /// </summary>
    public class IssuePrediction
    {
        /// <summary>
        /// Predicted area, bound to the model's "PredictedLabel" output column.
        /// </summary>
        /// <remarks>
        /// Declared as a read/write auto-property (previously a public field) so it is
        /// consistent with <c>IssueData</c>. ML.NET binds settable properties the same way
        /// it binds fields, so the column mapping is unchanged.
        /// </remarks>
        [ColumnName("PredictedLabel")]
        public string Area { get; set; }
    }
}
▶ Program.cs
|
using System;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace TestProject
{
    /// <summary>
    /// Sample program: trains an ML.NET multiclass classifier that predicts an issue's Area
    /// from its Title and Description. It then predicts one item, evaluates on test data,
    /// saves the model to disk, and reloads it for a second prediction.
    /// </summary>
    class Program
    {
        // ------------------------------------------------------------------ Field (static, private)

        #region Field

        /// <summary>
        /// Directory the application was loaded from.
        /// </summary>
        /// <remarks>
        /// Uses <see cref="AppContext.BaseDirectory"/> instead of
        /// <c>Path.GetDirectoryName(Environment.GetCommandLineArgs()[0])</c>. args[0] can be
        /// empty or relative, which makes GetDirectoryName return null or "". The data paths
        /// below would then throw or silently resolve against the current working directory.
        /// </remarks>
        private static string _applicationPath => AppContext.BaseDirectory;

        /// <summary>
        /// Training data file path (Data/issues_train.tsv under the application directory).
        /// </summary>
        private static string _trainingDataFilePath => Path.Combine(_applicationPath, "Data", "issues_train.tsv");

        /// <summary>
        /// Test data file path (Data/issues_test.tsv under the application directory).
        /// </summary>
        private static string _testDataFilePath => Path.Combine(_applicationPath, "Data", "issues_test.tsv");

        /// <summary>
        /// Path the trained model is saved to and reloaded from.
        /// </summary>
        private static string _modelFilePath => Path.Combine(_applicationPath, "Data", "model.zip");

        /// <summary>
        /// ML.NET context shared by all operations in this program.
        /// </summary>
        private static MLContext _context;

        /// <summary>
        /// Most recently created prediction engine (disposed before being replaced).
        /// </summary>
        private static PredictionEngine<IssueData, IssuePrediction> _predictionEngine;

        /// <summary>
        /// Model trained in <see cref="Main"/>.
        /// </summary>
        private static ITransformer _model;

        /// <summary>
        /// Training data view loaded in <see cref="Main"/>.
        /// </summary>
        private static IDataView _trainingDataView;

        #endregion

        // ------------------------------------------------------------------ Method (static, private)

        #region Main()

        /// <summary>
        /// Program entry point: load training data, build and fit the pipeline, predict a
        /// sample issue, evaluate on the test file, save the model, then predict with the
        /// reloaded model.
        /// </summary>
        private static void Main()
        {
            Console.WriteLine("BEGIN MAIN FUNCTION");

            // Seed the context so randomized ML.NET operations are repeatable between runs.
            _context = new MLContext(seed: 0);

            Console.WriteLine("BEGIN SET TRAINING DATA VIEW");
            _trainingDataView = _context.Data.LoadFromTextFile<IssueData>(_trainingDataFilePath, hasHeader: true);
            Console.WriteLine("END SET TRAINING DATA VIEW");

            Console.WriteLine("BEGIN SET PIPE LINE");
            var pipeline = GetPipeLine();
            Console.WriteLine("END SET PIPE LINE");

            _model = GetModel(_trainingDataView, pipeline);

            Console.WriteLine("BEGIN PREDICT SINGLE ITEM");
            _predictionEngine = _context.Model.CreatePredictionEngine<IssueData, IssuePrediction>(_model);

            IssueData issueData = new IssueData()
            {
                Title       = "WebSockets communication is slow in my machine",
                Description = "The WebSockets communication used under the covers by SignalR looks like is going slow in my development machine.."
            };

            IssuePrediction issuePrediction = _predictionEngine.Predict(issueData);
            PrintPrediction(issuePrediction);
            Console.WriteLine("END PREDICT SINGLE ITEM");

            Evaluate(_trainingDataView.Schema);
            SaveModel(_context, _trainingDataView.Schema, _model);
            PredictIssue();

            // PredictionEngine is IDisposable; release the last engine before exiting.
            _predictionEngine?.Dispose();
            _predictionEngine = null;

            Console.WriteLine("END MAIN FUNCTION");
        }

        #endregion

        #region GetPipeLine()

        /// <summary>
        /// Build the training pipeline.
        /// </summary>
        /// <returns>
        /// An estimator that maps Area to a key-typed Label, featurizes Title and Description,
        /// concatenates them into Features, trains SdcaMaximumEntropy, and maps PredictedLabel
        /// back to its original string value.
        /// </returns>
        private static IEstimator<ITransformer> GetPipeLine()
        {
            Console.WriteLine("BEGIN GET PIPE LINE FUNCTION");

            var pipeline = _context.Transforms.Conversion.MapValueToKey(inputColumnName: "Area", outputColumnName: "Label")
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName: "Title", outputColumnName: "TitleFeaturized"))
                .Append(_context.Transforms.Text.FeaturizeText(inputColumnName: "Description", outputColumnName: "DescriptionFeaturized"))
                .Append(_context.Transforms.Concatenate("Features", "TitleFeaturized", "DescriptionFeaturized"))
                // Cache the featurized data so the multi-pass trainer does not recompute the transforms above on every pass.
                .AppendCacheCheckpoint(_context)
                .Append(_context.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features"))
                // Convert the key-typed prediction back to the original Area string (IssuePrediction.Area is a string).
                .Append(_context.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            Console.WriteLine("END GET PIPE LINE FUNCTION");

            return pipeline;
        }

        #endregion

        #region GetModel(trainingDataView, pipeline)

        /// <summary>
        /// Fit the pipeline to the training data.
        /// </summary>
        /// <param name="trainingDataView">Training data view.</param>
        /// <param name="pipeline">Pipeline to fit.</param>
        /// <returns>The trained model.</returns>
        private static ITransformer GetModel(IDataView trainingDataView, IEstimator<ITransformer> pipeline)
        {
            Console.WriteLine("BEGIN GET MODEL FUNCTION");

            var model = pipeline.Fit(trainingDataView);

            Console.WriteLine("END GET MODEL FUNCTION");

            return model;
        }

        #endregion

        #region Evaluate(trainingDataViewSchema)

        /// <summary>
        /// Evaluate <see cref="_model"/> on the test data file and print multiclass metrics.
        /// </summary>
        /// <param name="trainingDataViewSchema">
        /// Training data view schema. It is not used by this method and is kept only for
        /// signature compatibility.
        /// </param>
        private static void Evaluate(DataViewSchema trainingDataViewSchema)
        {
            Console.WriteLine("BEGIN EVALUATE FUNCTION");

            IDataView testDataView = _context.Data.LoadFromTextFile<IssueData>(_testDataFilePath, hasHeader: true);

            MulticlassClassificationMetrics metrics = _context.MulticlassClassification.Evaluate(_model.Transform(testDataView));

            // "0.###" (not "#.###") so values below 1 keep their leading zero and 0 is not printed as an empty string.
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine("METRICS FOR MULTI-CLASS CLASSIFICATION MODEL - TEST DATA");
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine($"MICRO ACCURACY : {metrics.MicroAccuracy:0.###}");
            Console.WriteLine($"MACRO ACCURACY : {metrics.MacroAccuracy:0.###}");
            Console.WriteLine($"LOG LOSS : {metrics.LogLoss:0.###}");
            Console.WriteLine($"LOG LOSS REDUCTION : {metrics.LogLossReduction:0.###}");
            Console.WriteLine("--------------------------------------------------");

            Console.WriteLine("END EVALUATE FUNCTION");
        }

        #endregion

        #region SaveModel(context, trainingDataViewSchema, model)

        /// <summary>
        /// Save the model, together with its input schema, to <see cref="_modelFilePath"/>.
        /// </summary>
        /// <param name="context">ML context.</param>
        /// <param name="trainingDataViewSchema">Input schema stored alongside the model.</param>
        /// <param name="model">Model to save.</param>
        private static void SaveModel(MLContext context, DataViewSchema trainingDataViewSchema, ITransformer model)
        {
            Console.WriteLine("BEGIN SAVE MODEL FUNCTION");

            context.Model.Save(model, trainingDataViewSchema, _modelFilePath);

            Console.WriteLine("END SAVE MODEL FUNCTION");
        }

        #endregion

        #region PredictIssue()

        /// <summary>
        /// Reload the saved model from <see cref="_modelFilePath"/> and predict the Area of a
        /// sample issue.
        /// </summary>
        private static void PredictIssue()
        {
            Console.WriteLine("BEGIN PREDICT ISSUE FUNCTION");

            // The loaded input schema is not needed here.
            ITransformer model = _context.Model.Load(_modelFilePath, out _);

            IssueData issueData = new IssueData()
            {
                Title       = "Entity Framework crashes",
                Description = "When connecting to the database, EF is crashing"
            };

            // Dispose the previous engine before replacing it, so it is not leaked.
            _predictionEngine?.Dispose();
            _predictionEngine = _context.Model.CreatePredictionEngine<IssueData, IssuePrediction>(model);

            IssuePrediction issuePrediction = _predictionEngine.Predict(issueData);
            PrintPrediction(issuePrediction);

            Console.WriteLine("END PREDICT ISSUE FUNCTION");
        }

        #endregion

        #region PrintPrediction(issuePrediction)

        /// <summary>
        /// Print a prediction's Area between two separator lines.
        /// </summary>
        /// <param name="issuePrediction">Prediction to print.</param>
        private static void PrintPrediction(IssuePrediction issuePrediction)
        {
            Console.WriteLine("--------------------------------------------------");
            Console.WriteLine($"AREA : {issuePrediction.Area}");
            Console.WriteLine("--------------------------------------------------");
        }

        #endregion
    }
}