
[Code Implementation] Logistic Regression

Implemented with numpy.
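For reference, these are the quantities the class computes. X is the n×d design matrix (with a leading column of ones when fit_intercept=True), y ∈ {0,1}^n, and η is lr. The eps guard inside the logarithms is omitted. The gradient follows from σ'(z) = σ(z)(1 − σ(z)), which reduces the derivative of each per-sample loss term to (h_i − y_i)x_i.

```latex
h = \sigma(X\theta), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}

J(\theta) = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log h_i + (1 - y_i) \log(1 - h_i) \right]

\sigma'(z) = \sigma(z)\,(1 - \sigma(z)) \;\Rightarrow\; \nabla_\theta J = \frac{1}{n} X^\top (h - y)

\theta \leftarrow \theta - \eta \, \nabla_\theta J, \qquad
\hat{y}_i = \begin{cases} 1 & h_i > 0.5 \\ 0 & \text{otherwise} \end{cases}
```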

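import numpy as np


class LogisticRegression:
    """Binary logistic regression trained with batch gradient descent."""

    def __init__(self, lr=0.01, num_iter=1000, fit_intercept=True, verbose=False):
        self.lr = lr                        # learning rate
        self.num_iter = num_iter            # number of gradient descent steps
        self.fit_intercept = fit_intercept  # prepend a bias column of ones
        self.verbose = verbose              # print the loss every 10 iterations
        self.eps = 1e-10                    # keeps log() away from log(0)
        self.threshold = 0.5                # decision threshold on P(y=1)
        self.loss_history = list()

    def __add_intercept(self, X):
        # Prepend a column of ones so theta[0] acts as the bias term
        intercept = np.ones((X.shape[0], 1))
        return np.concatenate((intercept, X), axis=1)

    def __sigmoid(self, z):
        return 1 / (1 + np.exp(-z))

    def __loss(self, h, y):
        # Mean binary cross-entropy
        return (-y * np.log(h + self.eps) - (1 - y) * np.log(1 - h + self.eps)).mean()

    def fit(self, X, y):
        if self.fit_intercept:
            X = self.__add_intercept(X)

        # weights initialization
        self.theta = np.zeros(X.shape[1])
        self.loss_history = list()  # reset so repeated fit() calls don't accumulate

        for i in range(self.num_iter):
            logit = np.dot(X, self.theta)
            hypothesis = self.__sigmoid(logit)
            # Gradient of the mean cross-entropy: X^T (h - y) / n
            gradient = np.dot(X.T, (hypothesis - y)) / y.size
            self.theta -= self.lr * gradient

            # Record the loss every 10 iterations; print it only when verbose
            if i % 10 == 0:
                loss = self.__loss(hypothesis, y)
                self.loss_history.append(loss)
                if self.verbose:
                    print(f'epoch: {i} \t loss: {loss} \t')
        return self.loss_history

    def predict_prob(self, X):
        # P(y = 1 | x) for each row of X
        if self.fit_intercept:
            X = self.__add_intercept(X)
        return self.__sigmoid(np.dot(X, self.theta))

    def predict(self, X):
        # Class 1 if the probability exceeds the threshold, otherwise class 0
        predicted_labels = np.where(self.predict_prob(X) > self.threshold, 1, 0)
        return predicted_labels

    def eval(self, x, y):
        # Accuracy: fraction of samples whose predicted label matches y
        res_y = self.predict(x)
        accuracy = np.sum(res_y == y) / len(y)
        return accuracy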
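As a sanity check on the gradient used in fit(), the following sketch compares the analytic gradient X^T(h - y)/n with a central finite-difference estimate of the mean cross-entropy loss. It then runs a few hundred plain gradient descent steps from θ = 0 to confirm the loss goes down. So that it runs on its own, it reimplements the sigmoid, loss, and gradient as standalone functions. The function names, the toy dataset, and the step sizes are hypothetical choices made only for illustration.

```python
import numpy as np

def sigmoid(z):
    # Same formula as __sigmoid in the class
    return 1.0 / (1.0 + np.exp(-z))

def loss(theta, X, y, eps=1e-10):
    # Mean binary cross-entropy with the same eps guard as __loss
    h = sigmoid(X @ theta)
    return np.mean(-y * np.log(h + eps) - (1 - y) * np.log(1 - h + eps))

def analytic_grad(theta, X, y):
    # The gradient used in fit(): X^T (h - y) / n
    h = sigmoid(X @ theta)
    return X.T @ (h - y) / y.size

def numeric_grad(theta, X, y, step=1e-6):
    # Central finite differences, one coordinate at a time
    g = np.zeros_like(theta)
    for j in range(theta.size):
        e = np.zeros_like(theta)
        e[j] = step
        g[j] = (loss(theta + e, X, y) - loss(theta - e, X, y)) / (2 * step)
    return g

# Hypothetical toy data: a bias column of ones plus two Gaussian features
rng = np.random.default_rng(0)
X = np.hstack([np.ones((50, 1)), rng.normal(size=(50, 2))])
y = (X[:, 1] + X[:, 2] > 0).astype(float)

# 1) Gradient check at a random point
theta = rng.normal(size=3)
max_diff = np.max(np.abs(analytic_grad(theta, X, y) - numeric_grad(theta, X, y)))
print(f"max |analytic - numeric| = {max_diff:.2e}")

# 2) Plain batch gradient descent from zeros, as in fit()
theta_gd = np.zeros(3)
start = loss(theta_gd, X, y)  # equals log(2) at theta = 0
for _ in range(500):
    theta_gd -= 0.1 * analytic_grad(theta_gd, X, y)
end = loss(theta_gd, X, y)
print(f"loss: {start:.4f} -> {end:.4f}")
```

Because the cross-entropy loss is convex and smooth, a small enough learning rate makes each gradient step reduce it. This is why the loss values that fit() records every 10 iterations are expected to decrease steadily.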
