#!/usr/bin/env python3
# from https://gist.github.com/IevaZarina/ef63197e089169a9ea9f3109058a9679

from __future__ import print_function
import numpy as np
import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.datasets import dump_svmlight_file
import joblib
from sklearn.metrics import precision_score

iris = datasets.load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# use DMatrix for xgboost
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# use svmlight file for xgboost
dump_svmlight_file(X_train, y_train, 'dtrain.svm', zero_based=True)
dump_svmlight_file(X_test, y_test, 'dtest.svm', zero_based=True)
dtrain_svm = xgb.DMatrix('dtrain.svm')
dtest_svm = xgb.DMatrix('dtest.svm')

# set xgboost params
param = {
    'max_depth': 3,  # the maximum depth of each tree
    'eta': 0.3,  # the learning rate (step size shrinkage) used at each boosting round
    'verbosity': 0,  # quiet logging ('silent' was removed in newer XGBoost releases)
    'objective': 'multi:softprob',  # multiclass objective that outputs a probability per class
    'num_class': 3}  # the number of classes that exist in this dataset
num_round = 20  # the number of training iterations
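# Note: with 'multi:softprob', predict() returns an (n_samples, num_class)
# array of class probabilities, so the most confident class is recovered
# below with argmax over each row.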

#------------- numpy array ------------------
# training and testing - numpy matrices
bst = xgb.train(param, dtrain, num_round)
preds = bst.predict(dtest)

# extracting most confident predictions
best_preds = np.asarray([np.argmax(line) for line in preds])
print("Numpy array precision:", precision_score(y_test, best_preds, average='macro'))

# ------------- svm file ---------------------
# training and testing - svm file
bst_svm = xgb.train(param, dtrain_svm, num_round)
preds = bst_svm.predict(dtest_svm)

# extracting most confident predictions
best_preds_svm = [np.argmax(line) for line in preds]
print("Svm file precision:", precision_score(y_test, best_preds_svm, average='macro'))
# --------------------------------------------

# dump the models
bst.dump_model('dump.raw.txt')
bst_svm.dump_model('dump_svm.raw.txt')


# save the models for later
joblib.dump(bst, 'bst_model.pkl', compress=True)
joblib.dump(bst_svm, 'bst_svm_model.pkl', compress=True)
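# --------------------------------------------
# Minimal sketch: reload one of the pickled models later and reuse it
# (assumes the Booster round-trips through joblib.load, mirroring the
# joblib.dump calls above).
bst_loaded = joblib.load('bst_model.pkl')
loaded_preds = bst_loaded.predict(xgb.DMatrix(X_test))
best_loaded_preds = np.asarray([np.argmax(line) for line in loaded_preds])
print("Reloaded model precision:", precision_score(y_test, best_loaded_preds, average='macro'))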