Betrouwbaarheidsintervallen berekenen voor prestatiestatistieken in Machine Learning met behulp van een automatische bootstrap-methode

Bronknooppunt: 1178501

Betrouwbaarheidsintervallen berekenen voor prestatiestatistieken in Machine Learning met behulp van een automatische bootstrap-methode

Zijn de prestatiemetingen van uw model zeer nauwkeurig vanwege een ‘grote’ testset, of zeer onzeker vanwege een ‘kleine’ of onevenwichtige testset?


By David B. Rosen (PhD), Lead Data Scientist voor geautomatiseerde kredietgoedkeuring bij IBM Global Financing



De oranje lijn toont 89.7% als de ondergrens van het betrouwbaarheidsinterval voor evenwichtige nauwkeurigheid, groen voor de oorspronkelijk waargenomen evenwichtige nauwkeurigheid = 92.4% (puntschatting) en rood voor de bovengrens van 94.7%. (Dit en alle afbeeldingen zijn van de auteur, tenzij anders vermeld.)

Introductie

 
 
Als u rapporteert dat de prestaties van uw classificator Nauwkeurigheid = 94.8% en F1 = 92.3% hebben op een testset, betekent dit niet veel zonder iets te weten over de omvang en samenstelling van de testset. De foutmarge van deze prestatiemetingen zal sterk variëren, afhankelijk van de grootte van de testset, of, voor een onevenwichtige dataset, voornamelijk afhankelijk van het aantal onafhankelijke exemplaren van de testset. minderheid klasse die het bevat (meer kopieën van dezelfde instanties door oversampling helpt niet voor dit doel).

Als u een andere, onafhankelijke testset van vergelijkbare oorsprong zou kunnen verzamelen, is het onwaarschijnlijk dat de nauwkeurigheid en F1 van uw model op deze dataset hetzelfde zullen zijn, maar hoeveel verschillend zouden ze plausibel kunnen zijn? Een soortgelijke vraag wordt in de statistieken beantwoord als de Betrouwbaarheidsinterval van de meting.

Als we veel onafhankelijke steekproefdatasets uit de onderliggende populatie zouden trekken, dan zou voor 95% van die datasets de werkelijke onderliggende populatiewaarde van de metriek binnen het betrouwbaarheidsinterval van 95% liggen dat we voor die specifieke steekproefdataset zouden berekenen.

In dit artikel laten we u zien hoe u betrouwbaarheidsintervallen voor een willekeurig aantal Machine Learning-prestatiestatistieken tegelijk kunt berekenen, met een bootstrap-methode die webmaster. bepaalt hoeveel opstartvoorbeeldgegevenssets er standaard moeten worden gegenereerd.

Als u alleen maar wilt zien hoe u deze code kunt aanroepen om betrouwbaarheidsintervallen te berekenen, ga dan naar de sectie “Bereken de resultaten!” beneden.

De bootstrap-methodologie

 
 
Als we aanvullende testdatasets zouden kunnen halen uit de werkelijke verdeling die ten grondslag ligt aan de gegevens, zouden we de verdeling van de prestatiestatistiek(en) over die datasets kunnen zien. (Bij het tekenen van deze datasets zouden we niets doen om te voorkomen dat een identieke of soortgelijke instantie meerdere keren wordt getekend, hoewel dit slechts zelden zou kunnen gebeuren.)

Omdat we dat niet kunnen doen, is het beste wat we kunnen doen om aanvullende datasets uit de empirische distributie van deze testgegevensset, wat betekent dat er monsters worden genomen, met vervanging, van de instanties ervan om nieuwe bootstrap-voorbeeldgegevenssets te genereren. Bemonstering met vervanging betekent dat zodra we een bepaald exemplaar hebben getekend, we het er weer in stoppen, zodat we het opnieuw kunnen tekenen voor dezelfde voorbeeldgegevensset. Daarom heeft elke dergelijke dataset over het algemeen meerdere kopieën van sommige exemplaren, en omvat deze niet alle exemplaren die zich in de basistestset bevinden.

Als we proefden zonder vervanging, dan zouden we elke keer eenvoudigweg een identieke kopie van de originele dataset krijgen, geschud in een andere willekeurige volgorde, wat van geen enkel nut zou zijn.

De percentiel bootstrap-methodologie voor het schatten van het betrouwbaarheidsinterval is als volgt:

  1. Genereer nboots ‘bootstrap sample’-datasets, elk van dezelfde grootte als de originele testset. Elke voorbeelddataset wordt verkregen door willekeurig exemplaren uit de testset te trekken met vervanging.
  2. Bereken voor elk van de voorbeeldgegevenssets de metriek en sla deze op.
  3. Het 95% betrouwbaarheidsinterval wordt gegeven door de 2.5th naar de 97.5th percentiel onder de nboots berekende waarden van de metriek. Als nboots=1001 en je hebt de waarden gesorteerd in een reeks/matrix/lijst X met lengte 1001, de 0th percentiel is X[0] en de 100th percentiel is X[1000], dus het betrouwbaarheidsinterval zou worden gegeven door X[25] naar X[975].

Natuurlijk kunt u in stap 2 voor elke voorbeeldgegevensset zoveel statistieken berekenen als u wilt, maar in stap 3 vindt u de percentielen voor elke metriek afzonderlijk.

Voorbeeldgegevensset en betrouwbaarheidsintervalresultaten

 
 
We zullen de resultaten uit dit eerdere artikel als voorbeeld gebruiken: Hoe om te gaan met onevenwichtige classificatie, zonder de gegevens opnieuw in evenwicht te brengen?Voordat u overweegt uw scheve gegevens te overbemonsteren, kunt u proberen de drempelwaarde voor uw classificatiebeslissing aan te passen.

In dat artikel gebruikten we de zeer-onevenwichtige Kaggle met twee klassen identificatiegegevensset voor creditcardfraude. We hebben ervoor gekozen om een ​​classificatiedrempel te gebruiken die heel anders is dan de standaarddrempel van 0.5 die impliciet is bij het gebruik van de voorspellen()-methode, waardoor het niet nodig is om de gegevens in evenwicht te brengen. Deze aanpak wordt soms genoemd drempel beweegt, waarin onze classificator de klasse toewijst door de gekozen drempelwaarde toe te passen op de voorspelde klassewaarschijnlijkheid die door de voorspelling wordt geboden_waarschijnlijk() methode.

We zullen de reikwijdte van dit artikel (en de code) beperken tot binaire classificatie: klassen 0 en 1, waarbij klasse 1 volgens afspraak de ‘positieve’ klasse is en specifiek de minderheidsklasse voor onevenwichtige gegevens, hoewel de code zou moeten werken voor regressie (enkelvoudige continu doel) ook.

Eén opstartvoorbeeldgegevensset genereren

 
 
Hoewel onze betrouwbaarheidsintervalcode verschillende aantallen gegevensargumenten kan verwerken die aan de metrische functies moeten worden doorgegeven, zullen we ons concentreren op metrieken in sklearn-stijl, die altijd twee gegevensargumenten accepteren, y_true en y_pred, waarbij y_pred voorspellingen van binaire klassen zal zijn (0 of 1), of continue klasse-waarschijnlijkheids- of beslissingsfunctie-voorspellingen, of zelfs continue regressie-voorspellingen als y_true ook continu is. De volgende functie genereert een enkele opstartvoorbeeldgegevensset. Het accepteert alle data_args, maar in ons geval zullen deze argumenten dat wel zijn ytest(onze werkelijke/echte testset streefwaarden in de vorig artikel) en hardpredtst_tuned_thresh (de voorspelde klasse). Beide bevatten nullen en enen om de ware of voorspelde klasse voor elke instantie aan te geven.

Aangepaste metrische specificity_score() en hulpprogrammafuncties

 
 
We zullen een aangepaste metrische functie definiëren voor Specificiteit, wat gewoon een andere naam is voor het terugroepen van de negatief klasse (klasse 0). Ook een calc_metrics-functie die een reeks interessante statistieken toepast op onze gegevens, en een aantal hulpfuncties daarvoor:

Hier maken we onze lijst met statistieken en passen deze toe op de gegevens. We beschouwen Nauwkeurigheid niet als een relevante maatstaf, omdat een vals-negatief (een echte fraude verkeerd als legitiem classificeren) veel duurder is voor het bedrijf dan een vals-positief (een echte legitieme fraude verkeerd classificeert als fraude), terwijl Nauwkeurigheid beide soorten misclassificatie behandelt zijn even slecht en geven daarom de voorkeur aan het correct classificeren van degenen wier echte klasse de meerderheidsklasse is, omdat deze veel vaker voorkomen en dus veel meer bijdragen aan de algehele nauwkeurigheid.

met=[ metrics.recall_score, specificity_score, metrics.balanced_accuracy_score ]
calc_metrics(met, ytest, hardpredtst_tuned_thresh)



Het maken van elke opstartvoorbeeldgegevensset en het berekenen van de statistieken daarvoor

 
 
In raw_metric_samples() zullen we feitelijk meerdere voorbeelddatasets één voor één genereren en de statistieken van elke dataset opslaan:

U geeft raw_metric_samples() een lijst met metrieken (of slechts één metriek) van belang, evenals de ware en voorspelde klassegegevens, en het verkrijgt nboots voorbeelddatasets en retourneert een dataframe met alleen de waarden van de metrieken berekend op basis van elke dataset. Via _boot_generator() roept het one_boot() één voor één aan in een generatorexpressie in plaats van alle datasets tegelijk op te slaan als een potentieel-reusachtig lijst.

Bekijk de statistieken van zeven opstartvoorbeelddatasets

 
 
We maken onze lijst met metrische functies en roepen raw_metric_samples() aan om de resultaten voor slechts zeven voorbeeldgegevenssets te krijgen. We roepen hier raw_metric_samples() aan voor een beter begrip - het is niet nodig om betrouwbaarheidsintervallen te krijgen met behulp van ci_auto() hieronder, hoewel we een lijst met metrieken (of slechts één metriek) specificeren voor ci_auto() is noodzakelijk.

np.random.seed(13)
raw_metric_samples(met, ytest, hardpredtst_tuned_thresh, nboots=7).style.format('{:.2%}') #optional #style



Elke kolom hierboven bevat de statistieken die zijn berekend op basis van één opstartvoorbeeldgegevensset (genummerd van 0 tot en met 6), dus de berekende metrische waarden variëren als gevolg van de willekeurige steekproeven.

Aantal opstartgegevenssets, met berekende standaardwaarde

 
 
In onze implementatie is dit standaard het aantal opstartgegevenssets nboots wordt automatisch berekend op basis van het gewenste betrouwbaarheidsniveau (bijvoorbeeld 95%) om aan de aanbeveling te voldoen Noord, Curtis en Sham om een ​​minimum aantal opstartresultaten in elke staart van de verdeling te hebben. (Eigenlijk is deze aanbeveling van toepassing op p-waarden en dus hypothesetest acceptatiegebieden, Maar betrouwbaarheidsintervallen zijn vergelijkbaar genoeg met die om dit als vuistregel te gebruiken.) Hoewel deze auteurs een minimum van 10 opstartresultaten in de staart aanbevelen, Davidson & MacKinnon raden ten minste 399 laarzen aan voor 95% betrouwbaarheid, waarvoor 11 laarzen in de staart nodig zijn, dus we gebruiken deze meer conservatieve aanbeveling.

We specificeren alfa, namelijk 1 – betrouwbaarheidsniveau. Een betrouwbaarheid van 95% wordt bijvoorbeeld 0.95 en alpha=0.05. Als u een expliciet aantal boots opgeeft (misschien een kleiner nboots omdat je snellere resultaten wilt) maar het is niet genoeg voor de door jou gevraagde alpha, er wordt automatisch een hogere alpha gekozen om een ​​nauwkeurig betrouwbaarheidsinterval voor dat aantal boots te krijgen. Er zullen minimaal 51 schoenen worden gebruikt, omdat met minder slechts bizar kleine betrouwbaarheidsniveaus nauwkeurig kunnen worden berekend (zoals een betrouwbaarheid van 40%, wat een interval oplevert vanaf de 30th percentiel tot 70th percentiel, dat 40% binnen het interval heeft, maar 60% daarbuiten) en het is niet duidelijk dat de aanbeveling voor minimale laarzen zelfs maar een dergelijk geval in overweging nam.

De functie get_alpha_nboots() stelt de standaard nboots in of wijzigt de gevraagde alpha en nboots zoals hierboven:

Laten we de standaard nboots tonen voor verschillende waarden van alpha:

g = get_alpha_nboots pd.DataFrame( [ g(0.40), g(0.20, None), g(0.10), g(), g(alpha=0.02), g(alpha=0.01, nboots=None), g(0.005, nboots=None) ], columns=['alpha', 'default nboots'] ).set_index('alpha')



Dit is wat er gebeurt als we een expliciete nboots aanvragen:

req=[(0.01,3000), (0.01,401), (0.01,2)]
out=[get_alpha_nboots(*args) for args in req]
mydf = lambda x: pd.DataFrame(x, columns=['alpha', 'nboots'])
pd.concat([mydf(req),mydf(out)],axis=1, keys=('Requested','Using'))



Kleine nboots-waarden verhoogden alfa naar 0.05 en 0.40, en nboots=2 werd gewijzigd naar het minimum van 51.

Histogram van bootstrap-voorbeeldgegevenssets die het betrouwbaarheidsinterval tonen, alleen voor gebalanceerde nauwkeurigheid

 
 
Nogmaals, we hoeven dit niet te doen om de onderstaande betrouwbaarheidsintervallen te krijgen door ci_auto() aan te roepen.

np.random.seed(13)
metric_boot_histogram (metrics.balanced_accuracy_score, ytest, hardpredtst_tuned_thresh)



De oranje lijn toont 89.7% als de ondergrens van het gebalanceerde nauwkeurigheidsbetrouwbaarheidsinterval, groen voor de oorspronkelijk waargenomen gebalanceerde nauwkeurigheid = 92.4% (puntschatting) en rood voor de bovengrens van 94.7%. (Dezelfde afbeelding verschijnt bovenaan dit artikel.)

Hoe u alle betrouwbaarheidsintervallen voor de lijst met statistieken kunt berekenen

 
 
Hier is de hoofdfunctie die het bovenstaande aanroept en de betrouwbaarheidsintervallen berekent op basis van de percentielen van de metrische resultaten, en de puntschattingen invoegt als de eerste kolom van het uitvoerdataframe met resultaten.

Bereken de resultaten!

 
 
Dit is alles wat we echt hoefden te doen: roep ci_auto() als volgt aan met een lijst met statistieken (met hierboven toegewezen) om hun betrouwbaarheidsintervallen te verkrijgen. De percentageopmaak is optioneel:

np.random.seed(13)
ci_auto( met, ytest, hardpredtst_tuned_thresh ).style.format('{:.2%}')



Bespreking van de resulterende betrouwbaarheidsintervallen

 
 
Hier is de verwarringsmatrix van de originele artikel. Klasse 0 is de negatieven (meerderheidsklasse) en Klasse 1 is de positieven (zeer zeldzame klasse)



De Recall (True Positive Rate) van 134/(134+14) heeft het breedste betrouwbaarheidsinterval omdat dit een binominale verhouding is waarbij kleine tellingen betrokken zijn.

De specificiteit (werkelijk negatief percentage) is 80,388/(80,388+4,907), wat inhoudt veel grotere aantallen, dus het heeft een extreem smal betrouwbaarheidsinterval van slechts [94.11% tot 94.40%].

Omdat de Evenwichtige Nauwkeurigheid eenvoudigweg wordt berekend als een gemiddelde van de Recall en de Specificiteit, ligt de breedte van het betrouwbaarheidsinterval tussen die van hen in.

Metrische onnauwkeurigheid van metingen als gevolg van variaties in testgegevens versus variaties in treingegevens

 
 
Hier hebben we geen rekening gehouden met de variabiliteit in de model gebaseerd op de willekeur van onze opleiding gegevens (hoewel dat voor sommige doeleinden ook van belang kan zijn, bijvoorbeeld als u herhaalde herscholing geautomatiseerd heeft en wilt weten hoeveel de prestaties van toekomstige modellen kunnen variëren), maar eerder alleen de variabiliteit in de meting van de prestaties hiervan bijzonder model (gemaakt op basis van bepaalde trainingsgegevens) vanwege de willekeur van onze proef data.

Als we voldoende onafhankelijke testgegevens hadden, zouden we de prestaties van dit specifieke model op de onderliggende populatie zeer nauwkeurig kunnen meten, en zouden we weten hoe het zal presteren als dit model wordt ingezet, ongeacht hoe we het model hebben gebouwd en of we het een beter of slechter model verkrijgen met een andere trainingsvoorbeelddataset.

Onafhankelijkheid van individuele gevallen

 
 
De bootstrap-methode gaat ervan uit dat elk van uw gevallen (gevallen, observaties) onafhankelijk van een onderliggende populatie wordt getrokken. Als uw testset groepen rijen bevat die niet onafhankelijk van elkaar zijn, bijvoorbeeld herhaalde waarnemingen van dezelfde entiteit die waarschijnlijk met elkaar gecorreleerd zijn, of instanties die overbemonsterd/gerepliceerd/gegenereerd zijn uit andere instanties in uw test ingesteld, zijn de resultaten mogelijk niet geldig. Mogelijk moet u gebruiken gegroepeerd sampling, waarbij u hele groepen willekeurig bij elkaar trekt in plaats van individuele rijen, terwijl u vermijdt dat u een groep opsplitst of slechts een deel ervan gebruikt.

Je wilt er ook voor zorgen dat je geen groepen hebt die verdeeld zijn over de trainings- en testset, omdat de testset dan niet noodzakelijkerwijs onafhankelijk is en je ongemerkt overfitting kunt krijgen. Als u bijvoorbeeld oversampling gebruikt, moet u dit doorgaans alleen doen na het is afgesplitst van de testset, niet eerder. En normaal gesproken zou u de trainingsset oversamplen, maar niet de testset, omdat de testset representatief moet blijven voor de instanties die het model zal zien bij toekomstige implementatie. En voor kruisvalidatie zou je scikit-learn's willen gebruiken model_selection.GroupKFold().

Conclusie

 
 
U kunt altijd betrouwbaarheidsintervallen voor uw evaluatiestatistiek(en) berekenen om te zien hoe nauwkeurig uw testgegevens u in staat stellen de prestaties van uw model te meten. Ik ben van plan nog een artikel te schrijven om betrouwbaarheidsintervallen te demonstreren voor metrieken die waarschijnlijkheidsvoorspellingen evalueren (of betrouwbaarheidsscores – geen relatie met statistische betrouwbaarheid), dat wil zeggen zachte classificatie, zoals Log Loss of ROC AUC, in plaats van de metrieken die we hier hebben gebruikt en die de discrete klassekeuze door het model (harde classificatie). Dezelfde code werkt voor beide, maar ook voor regressie (het voorspellen van een continue doelvariabele) – je hoeft er alleen maar een ander soort voorspelling aan door te geven (en een ander soort echte doelen in het geval van regressie).

Deze jupyter-notebook is beschikbaar in github: bootConfIntAutoV1o_standalone.ipynb

Was dit artikel informatief en/of nuttig? Plaats hieronder een reactie als u opmerkingen of vragen heeft over dit artikel of over betrouwbaarheidsintervallen, de bootstrap, het aantal boots, deze implementatie, dataset, model, drempelverplaatsing of resultaten.

Naast het bovengenoemde vorig artikel, misschien ben je ook geïnteresseerd in de mijne Hoe de datum/datum/tijd-kolommen automatisch te detecteren en hun gegevenstype in te stellen bij het lezen van een CSV-bestand in Panda's, hoewel het niet direct verband houdt met het huidige artikel.

Sommige rechten voorbehouden

 
Bio: David B. Rosen (PhD) is Lead Data Scientist voor Automated Credit Approval bij IBM Global Financing. Vind meer van David's schrijven op dabruro.medium.com.

ORIGINELE. Met toestemming opnieuw gepost.

Zie ook:

Bron: https://www.kdnuggets.com/2021/10/calculate-confidence-intervals-performance-metrics-machine-learning.html

Tijdstempel:

Meer van KDnuggets