# text-classifier.py
# Lobang Club Dev Challenge Submission by Ray Dino
# Version 6
# May 18, 2012
from sys import argv
import csv
import math
def unicode_csv_reader(utf8_data, dialect=csv.excel, **kwargs):
    """Yield each CSV record from *utf8_data* as a list of cell strings.

    This is a thin generator wrapper around ``csv.reader``. Despite the name,
    no decoding or re-encoding is performed here: cells are yielded exactly as
    ``csv.reader`` produces them.

    Args:
        utf8_data: Any iterable of lines accepted by ``csv.reader`` (an open
            file object, a list of strings, ...).
        dialect: CSV dialect forwarded to ``csv.reader``.
        **kwargs: Extra formatting options forwarded to ``csv.reader``
            (``delimiter``, ``quotechar``, ...).

    Yields:
        One list of cells per record.
    """
    # csv.reader already returns a fresh list for every record, so the rows
    # can be passed straight through without copying them cell by cell.
    for row in csv.reader(utf8_data, dialect=dialect, **kwargs):
        yield row
# Characters stripped from both ends of every whitespace-separated word.
_STRIP_CHARS = " ,'\":~`!@#$%^&*()-_+={}[]\\/?>.<|1234567890"
# A term is counted only if it is strictly longer than this.
_MIN_TERM_LEN = 1
# Upper-cased words that are never counted as terms.
_STOPWORDS = frozenset(["AND", "IN", "FOR", "THE", "A", "AN", "MY"])
# A category is also reported when its probability is within this fraction
# of the best category's probability.
_TIE_MARGIN = 0.3


def _open_csv_input(path):
    """Open *path* for reading by the csv module on Python 2 or Python 3."""
    import sys  # local import: the file itself only imports ``argv`` from sys

    if sys.version_info[0] >= 3:
        # The "U" mode flag no longer exists on recent Python 3 releases, and
        # the csv module asks for newline="" so it can handle line endings.
        return open(path, newline="")
    return open(path, "rU")


def _tokenize(text):
    """Return the upper-cased, stripped terms of *text* that are worth counting."""
    terms = []
    for word in text.split():
        term = word.upper().strip(_STRIP_CHARS)
        if len(term) > _MIN_TERM_LEN and term not in _STOPWORDS:
            terms.append(term)
    return terms


def _learn(text, label, categories, origcategories, termcount, termcatcount):
    """Record one labelled example in the category and term count tables."""
    cat = label.upper().strip()
    if cat not in categories:
        # Remember the first spelling seen so it can be written to the results.
        origcategories[cat] = label.strip()
        categories[cat] = 0
    categories[cat] += 1
    for term in _tokenize(text):
        # Counts start at 1, so the first sighting is stored as 2. Scoring
        # treats a missing entry as 1, which keeps every ratio above zero.
        termcount[term] = termcount.get(term, 1) + 1
        termcatcount[(term, cat)] = termcatcount.get((term, cat), 1) + 1


def _classify(text, categories, termcount, termcatcount):
    """Return the category keys chosen for *text*, best first ([] if no known term)."""
    terms = _tokenize(text)
    if not any(term in termcount for term in terms):
        return []

    # Scores stay in log space: exponentiating a long product of small ratios
    # can underflow to 0.0, which would leave no category selectable.
    scores = {}
    for cat in categories:
        score = 0.0
        for term in terms:
            # Look up with a default instead of inserting, so scoring one row
            # never changes which terms count as "known" for later rows.
            in_cat = termcatcount.get((term, cat), 1)
            overall = termcount.get(term, 1)
            score += math.log(float(in_cat) / float(overall))
        scores[cat] = score
    if not scores:
        return []

    best = None
    for cat in scores:
        # Strict comparison keeps the first category among exact ties.
        if best is None or scores[cat] > scores[best]:
            best = cat

    # p >= best_p * (1 - margin)  <=>  log(p) >= log(best_p) + log(1 - margin)
    cutoff = scores[best] + math.log(1.0 - _TIE_MARGIN)
    chosen = [best]
    for cat in scores:
        if cat != best and scores[cat] >= cutoff:
            chosen.append(cat)
    return chosen


def main():
    """Train a term-frequency classifier and label the competition set.

    Command line: ``<categories> <training set> <competition set>`` (three
    CSV paths read from ``argv``).

    * categories file: column 0 is a category name; the words of the name
      itself are counted as terms of that category.
    * training file: the first row is skipped as a header; column 0 is the
      text and column 1 its category.
    * competition file: column 0 is the text to classify.

    For each competition row containing at least one term seen in training,
    one ``text,category`` line is written to ``results.csv`` for the
    best-scoring category and for each category whose probability is within
    30% of the best. Rows with no known term produce no output. Blank rows,
    and training rows with fewer than two columns, are skipped.

    Raises:
        SystemExit: with status 1 when the argument count is wrong.
    """
    if len(argv) != 4:
        print("Usage: %s <categories> <training set> <competition set>" % (argv[0]))
        raise SystemExit(1)
    category_arg, training_arg, competition_arg = argv[1:4]

    # Upper-cased category -> example count (only the keys are used later).
    categories = {}
    # Upper-cased category -> spelling as first seen in the input.
    origcategories = {}
    termcatcount = {}
    termcount = {}

    with _open_csv_input(category_arg) as categories_file:
        for row in unicode_csv_reader(categories_file):
            if not row:
                continue
            _learn(row[0], row[0], categories, origcategories,
                   termcount, termcatcount)

    with _open_csv_input(training_arg) as training_file:
        reader = unicode_csv_reader(training_file)
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if len(row) < 2:
                continue
            _learn(row[0], row[1], categories, origcategories,
                   termcount, termcatcount)

    result_name = "results.csv"
    with open(result_name, "w") as result_file:
        with _open_csv_input(competition_arg) as competition_file:
            for row in unicode_csv_reader(competition_file):
                if not row:
                    continue
                chosen = _classify(row[0], categories, termcount, termcatcount)
                for cat in chosen:
                    # NOTE(review): fields are written without CSV quoting, so
                    # a text containing a comma yields an ambiguous line.
                    # Left as-is to keep the output format unchanged.
                    result_file.write("%s,%s\n" % (row[0], origcategories[cat]))
    print("%s result file created!" % result_name)
# Run the classifier only when executed as a script, so importing this
# module does not immediately read argv and touch the filesystem.
if __name__ == "__main__":
    main()