Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-11-25 02:29:52

0001 #!/usr/bin/env python3
0002 # from https://gist.github.com/IevaZarina/ef63197e089169a9ea9f3109058a9679
0003 
0004 import numpy as np
0005 import xgboost as xgb
0006 from sklearn import datasets
0007 from sklearn.model_selection import train_test_split
0008 from sklearn.datasets import dump_svmlight_file
0009 import joblib
0010 from sklearn.metrics import precision_score
0011 
# Load the scikit-learn Iris toy dataset (features X, integer class labels y).
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Hold out 20% of the samples for evaluation; the fixed seed keeps the split
# (and therefore the saved models) reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Path 1: build XGBoost DMatrix objects directly from the in-memory arrays.
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Path 2: round-trip the same split through svmlight/libsvm text files.
# zero_based=True writes feature indices starting at 0, like the column
# indices of the NumPy arrays used in path 1.
dump_svmlight_file(X_train, y_train, 'dtrain.svm', zero_based=True)
dump_svmlight_file(X_test, y_test, 'dtest.svm', zero_based=True)
# Recent XGBoost releases refuse to guess the format of a text input file, so
# declare it explicitly in the URI; older releases accept this suffix too.
dtrain_svm = xgb.DMatrix('dtrain.svm?format=libsvm')
dtest_svm = xgb.DMatrix('dtest.svm?format=libsvm')

# XGBoost booster parameters shared by both training paths below.
param = {
    'max_depth': 3,  # maximum depth of each tree
    'eta': 0.3,  # learning rate (shrinkage applied to each boosting step)
    # Quiet logging. This replaces the removed 'silent': 1 option, which
    # current XGBoost ignores with an "unused parameter" warning.
    'verbosity': 0,
    # Multiclass classification; predict() returns one probability per class
    # for every sample.
    'objective': 'multi:softprob',
    'num_class': 3,  # number of classes in this dataset
}
num_round = 20  # number of boosting rounds (trees added per class)

# ------------- numpy array ------------------
# Train on the DMatrix built from the in-memory arrays, then score the
# held-out split.
bst = xgb.train(param, dtrain, num_round)
preds = bst.predict(dtest)

# With 'multi:softprob', each row of preds holds one probability per class;
# the predicted label is the column index of the largest probability.
best_preds = np.argmax(preds, axis=1)
numpy_precision = precision_score(y_test, best_preds, average='macro')
print("Numpy array precision:", numpy_precision)

# ------------- svm file ---------------------
# Train on the DMatrix loaded from the svmlight files, then score the
# held-out split that was loaded the same way.
bst_svm = xgb.train(param, dtrain_svm, num_round)
# Predict with bst_svm, the model trained on the svm-file data, so that this
# section actually evaluates that model rather than the numpy-trained one.
preds = bst_svm.predict(dtest_svm)

# Pick the most probable class per row (one probability column per class),
# matching the numpy path above.
best_preds_svm = np.argmax(preds, axis=1)
print("Svm file precision:", precision_score(y_test, best_preds_svm, average='macro'))
# --------------------------------------------

# Write a human-readable text dump of every tree in each model.
for booster, text_path in ((bst, 'dump.raw.txt'), (bst_svm, 'dump_svm.raw.txt')):
    booster.dump_model(text_path)

# Persist both Boosters with joblib (compressed pickles) for later reuse.
# NOTE(review): pickled Boosters are tied to the XGBoost version that wrote
# them; Booster.save_model is the portable format if cross-version loading
# is ever needed.
for booster, pickle_path in ((bst, 'bst_model.pkl'), (bst_svm, 'bst_svm_model.pkl')):
    joblib.dump(booster, pickle_path, compress=True)